gopidesupavan commented on PR #62963:
URL: https://github.com/apache/airflow/pull/62963#issuecomment-4274506838

   And here is just a pseudo-diff:
   
   
   ```
   diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/utils/dq_models.py 
b/providers/common/ai/src/airflow/providers/common/ai/utils/dq_models.py
   @@
   +from typing import Any, Literal
   +
   +class ValidationSpec(BaseModel):
   +    kind: Literal[
   +        "null_pct",
   +        "row_count_min",
   +        "duplicate_pct",
   +        "exact_value",
   +        "numeric_between",
   +        "regex_row_level",
   +    ]
   +    params: dict[str, Any] = Field(default_factory=dict)
   +
    class DQCheck(BaseModel):
   -    check_name: str
   +    name: str
   +    description: str
        metric_key: str
        group_id: str
        check_category: str = ""
   +    validation: ValidationSpec
        unexpected_query: str | None = None
        row_level: bool = False
   
   diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/dq_validation.py 
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/dq_validation.py
   new file mode 100644
   @@
   +from __future__ import annotations
   +import json
   +from typing import Any
   +from pydantic_ai.tools import ToolDefinition
   +from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
   +from pydantic_core import SchemaValidator, core_schema
   +
   +_ARGS = SchemaValidator(core_schema.any_schema())
   +
   +class DQValidationToolset(AbstractToolset[Any]):
   +    @property
   +    def id(self) -> str:
   +        return "dq-validation"
   +
   +    async def get_tools(self, ctx):
   +        defs = [
   +            ToolDefinition(
   +                name="null_pct_check",
   +                description="Null or missing percentage check. Returns 
validation kind + params.",
   +                parameters_json_schema={"type": "object", "properties": 
{"max_pct": {"type": "number"}}, "required": ["max_pct"]},
   +                sequential=True,
   +            ),
   +            ToolDefinition(
   +                name="row_count_min_check",
   +                description="Minimum row count check.",
   +                parameters_json_schema={"type": "object", "properties": 
{"min_count": {"type": "integer"}}, "required": ["min_count"]},
   +                sequential=True,
   +            ),
   +            ToolDefinition(
   +                name="duplicate_pct_check",
   +                description="Duplicate percentage check.",
   +                parameters_json_schema={"type": "object", "properties": 
{"max_pct": {"type": "number"}}, "required": ["max_pct"]},
   +                sequential=True,
   +            ),
   +            ToolDefinition(
   +                name="regex_row_level_check",
   +                description="Row-level regex validation with max invalid 
percentage.",
   +                parameters_json_schema={"type": "object", "properties": 
{"pattern": {"type": "string"}, "max_invalid_pct": {"type": "number"}}, 
"required": ["pattern", "max_invalid_pct"]},
   +                sequential=True,
   +            ),
   +        ]
   +        return {
   +            d.name: ToolsetTool(toolset=self, tool_def=d, max_retries=1, 
args_validator=_ARGS)
   +            for d in defs
   +        }
   +
   +    async def call_tool(self, name, tool_args, ctx, tool):
   +        mapping = {
   +            "null_pct_check": {"kind": "null_pct", "check_category": 
"null_check", "row_level": False},
   +            "row_count_min_check": {"kind": "row_count_min", 
"check_category": "row_count", "row_level": False},
   +            "duplicate_pct_check": {"kind": "duplicate_pct", 
"check_category": "uniqueness", "row_level": False},
   +            "regex_row_level_check": {"kind": "regex_row_level", 
"check_category": "row_level", "row_level": True},
   +        }
   +        return json.dumps({"kind": mapping[name]["kind"], "params": 
tool_args, "check_category": mapping[name]["check_category"], "row_level": 
mapping[name]["row_level"]})
   
   
   diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py 
b/providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py
   @@
   -from airflow.providers.common.ai.utils.dq_models import DQCheckGroup, 
DQPlan, RowLevelResult, UnexpectedResult
   +from airflow.providers.common.ai.utils.dq_models import DQCheckGroup, 
DQPlan, RowLevelResult, UnexpectedResult
   +from airflow.providers.common.ai.toolsets.dq_validation import 
DQValidationToolset
   +from airflow.providers.common.ai.toolsets.sql import SQLToolset
   @@
   -    def generate_plan(self, prompts: dict[str, str], schema_context: str) 
-> DQPlan:
   +    def generate_plan(self, checks: list[dict[str, str]], schema_context: 
str) -> DQPlan:
   @@
   -        system_prompt = _PLANNING_SYSTEM_PROMPT.format(...)
   +        system_prompt = """
   +You are a SQL data-quality planner.
   +For each requested check:
   +1. Inspect schema with SQL tools if needed.
   +2. Call exactly one DQ validation tool to choose the validation kind.
   +3. Generate SQL that returns a metric matching that validation kind.
   +4. Group by (table, check_category).
   +5. Return only a DQPlan.
   +"""
   @@
   -        user_message = self._build_user_message(prompts)
   +        user_message = self._build_user_message(checks)
   @@
   -        agent = self._llm_hook.create_agent(output_type=DQPlan, 
instructions=system_prompt, **self._agent_params)
   +        agent = self._llm_hook.create_agent(
   +            output_type=DQPlan,
   +            instructions=system_prompt,
   +            toolsets=[
   +                SQLToolset(db_conn_id=self._db_hook.conn_name_attr, 
allowed_tables=_extract_table_names(schema_context) or None),
   +                DQValidationToolset(),
   +            ],
   +            **self._agent_params,
   +        )
   @@
   -    def _build_user_message(prompts: dict[str, str]) -> str:
   -        for check_name, description in prompts.items():
   -            lines.append(f'  - check_name="{check_name}": {description}')
   +    def _build_user_message(checks: list[dict[str, str]]) -> str:
   +        for check in checks:
   +            lines.append(f'  - name="{check["name"]}": 
{check["description"]}')
   
   
   diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py
 
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py
   @@
   -from airflow.providers.common.ai.utils.dq_validation import default_registry
   @@
   -        prompts: dict[str, str],
   +        checks: list[dict[str, str]],
   @@
   -        validators: dict[str, Callable[[Any], bool]] | None = None,
            dialect: str | None = None,
   @@
   -        self.prompts = prompts
   -        self.validators = validators or {}
   +        self.checks = checks
   @@
   -        self._validate_prompts()
   -        self._validate_validator_keys()
   +        self._validate_checks()
   @@
   -        plan = self._load_or_generate_plan(planner, schema_ctx)
   +        plan = self._load_or_generate_plan(planner, schema_ctx)
   @@
   -        return default_registry.build_llm_context(self.validators)
   +        return ""
   @@
   -        plan_hash = _compute_plan_hash(self.prompts, ...)
   +        plan_hash = _compute_plan_hash(self.checks, ...)
   @@
   -        plan = planner.generate_plan(self.prompts, schema_ctx)
   +        plan = planner.generate_plan(self.checks, schema_ctx)
   @@
   -                validator = self.validators.get(check.check_name)
   -                ...
   -                passed = bool(validator(value))
   +                passed = self._apply_validation(value, check.validation)
   @@
   +    def _apply_validation(self, value: Any, spec) -> bool:
   +        if spec.kind == "null_pct":
   +            return float(value) <= float(spec.params["max_pct"])
   +        if spec.kind == "row_count_min":
   +            return int(value) >= int(spec.params["min_count"])
   +        if spec.kind == "duplicate_pct":
   +            return float(value) <= float(spec.params["max_pct"])
   +        if spec.kind == "exact_value":
   +            return value == spec.params["expected"]
   +        if spec.kind == "numeric_between":
   +            return float(spec.params["min_val"]) <= float(value) <= 
float(spec.params["max_val"])
   +        raise ValueError(f"Unsupported aggregate validation kind: 
{spec.kind}")
   +
   +    def _validate_checks(self) -> None:
   +        if not self.checks:
   +            raise ValueError("checks must not be empty")
   
   
   diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py
 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py
   @@
   -    LLMDataQualityOperator(
   -        prompts={
   -            "null_order_id": "Check the percentage of rows where order_id 
is NULL",
   -            "duplicate_orders": "Calculate the percentage of duplicate 
order_id values",
   -            "min_row_count": "Count the total number of rows in the orders 
table",
   -        },
   -        validators={
   -            "null_order_id": null_pct_check(max_pct=0.0),
   -            "duplicate_orders": duplicate_pct_check(max_pct=0.0),
   -            "min_row_count": row_count_check(min_count=10_000),
   -        },
   -    )
   +    LLMDataQualityOperator(
   +        checks=[
   +            {"name": "null_order_id", "description": "order_id should never 
be null"},
   +            {"name": "duplicate_orders", "description": "duplicate order_id 
percentage should be 0%"},
   +            {"name": "min_row_count", "description": "orders table should 
have at least 10,000 rows"},
   +        ],
   +    )
   
   
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to