gopidesupavan commented on PR #62963:
URL: https://github.com/apache/airflow/pull/62963#issuecomment-4274506838
And here is just a pseudo diff:
```
diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/dq_models.py b/providers/common/ai/src/airflow/providers/common/ai/utils/dq_models.py
@@
+from typing import Any, Literal
+
+class ValidationSpec(BaseModel):
+ kind: Literal[
+ "null_pct",
+ "row_count_min",
+ "duplicate_pct",
+ "exact_value",
+ "numeric_between",
+ "regex_row_level",
+ ]
+ params: dict[str, Any] = Field(default_factory=dict)
+
class DQCheck(BaseModel):
- check_name: str
+ name: str
+ description: str
metric_key: str
group_id: str
check_category: str = ""
+ validation: ValidationSpec
unexpected_query: str | None = None
row_level: bool = False
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/dq_validation.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/dq_validation.py
new file mode 100644
@@
+from __future__ import annotations
+import json
+from typing import Any
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+_ARGS = SchemaValidator(core_schema.any_schema())
+
+class DQValidationToolset(AbstractToolset[Any]):
+ @property
+ def id(self) -> str:
+ return "dq-validation"
+
+ async def get_tools(self, ctx):
+ defs = [
+ ToolDefinition(
+ name="null_pct_check",
+ description="Null or missing percentage check. Returns
validation kind + params.",
+ parameters_json_schema={"type": "object", "properties":
{"max_pct": {"type": "number"}}, "required": ["max_pct"]},
+ sequential=True,
+ ),
+ ToolDefinition(
+ name="row_count_min_check",
+ description="Minimum row count check.",
+ parameters_json_schema={"type": "object", "properties":
{"min_count": {"type": "integer"}}, "required": ["min_count"]},
+ sequential=True,
+ ),
+ ToolDefinition(
+ name="duplicate_pct_check",
+ description="Duplicate percentage check.",
+ parameters_json_schema={"type": "object", "properties":
{"max_pct": {"type": "number"}}, "required": ["max_pct"]},
+ sequential=True,
+ ),
+ ToolDefinition(
+ name="regex_row_level_check",
+ description="Row-level regex validation with max invalid
percentage.",
+ parameters_json_schema={"type": "object", "properties":
{"pattern": {"type": "string"}, "max_invalid_pct": {"type": "number"}},
"required": ["pattern", "max_invalid_pct"]},
+ sequential=True,
+ ),
+ ]
+ return {
+ d.name: ToolsetTool(toolset=self, tool_def=d, max_retries=1,
args_validator=_ARGS)
+ for d in defs
+ }
+
+ async def call_tool(self, name, tool_args, ctx, tool):
+ mapping = {
+ "null_pct_check": {"kind": "null_pct", "check_category":
"null_check", "row_level": False},
+ "row_count_min_check": {"kind": "row_count_min",
"check_category": "row_count", "row_level": False},
+ "duplicate_pct_check": {"kind": "duplicate_pct",
"check_category": "uniqueness", "row_level": False},
+ "regex_row_level_check": {"kind": "regex_row_level",
"check_category": "row_level", "row_level": True},
+ }
+ return json.dumps({"kind": mapping[name]["kind"], "params":
tool_args, "check_category": mapping[name]["check_category"], "row_level":
mapping[name]["row_level"]})
diff --git a/providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py b/providers/common/ai/src/airflow/providers/common/ai/utils/dq_planner.py
@@
-from airflow.providers.common.ai.utils.dq_models import DQCheckGroup,
DQPlan, RowLevelResult, UnexpectedResult
+from airflow.providers.common.ai.utils.dq_models import DQCheckGroup,
DQPlan, RowLevelResult, UnexpectedResult
+from airflow.providers.common.ai.toolsets.dq_validation import
DQValidationToolset
+from airflow.providers.common.ai.toolsets.sql import SQLToolset
@@
- def generate_plan(self, prompts: dict[str, str], schema_context: str)
-> DQPlan:
+ def generate_plan(self, checks: list[dict[str, str]], schema_context:
str) -> DQPlan:
@@
- system_prompt = _PLANNING_SYSTEM_PROMPT.format(...)
+ system_prompt = """
+You are a SQL data-quality planner.
+For each requested check:
+1. Inspect schema with SQL tools if needed.
+2. Call exactly one DQ validation tool to choose the validation kind.
+3. Generate SQL that returns a metric matching that validation kind.
+4. Group by (table, check_category).
+5. Return only a DQPlan.
+"""
@@
- user_message = self._build_user_message(prompts)
+ user_message = self._build_user_message(checks)
@@
- agent = self._llm_hook.create_agent(output_type=DQPlan,
instructions=system_prompt, **self._agent_params)
+ agent = self._llm_hook.create_agent(
+ output_type=DQPlan,
+ instructions=system_prompt,
+ toolsets=[
+ SQLToolset(db_conn_id=self._db_hook.conn_name_attr,
allowed_tables=_extract_table_names(schema_context) or None),
+ DQValidationToolset(),
+ ],
+ **self._agent_params,
+ )
@@
- def _build_user_message(prompts: dict[str, str]) -> str:
- for check_name, description in prompts.items():
- lines.append(f' - check_name="{check_name}": {description}')
+ def _build_user_message(checks: list[dict[str, str]]) -> str:
+ for check in checks:
+ lines.append(f' - name="{check["name"]}":
{check["description"]}')
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_data_quality.py
@@
-from airflow.providers.common.ai.utils.dq_validation import default_registry
@@
- prompts: dict[str, str],
+ checks: list[dict[str, str]],
@@
- validators: dict[str, Callable[[Any], bool]] | None = None,
dialect: str | None = None,
@@
- self.prompts = prompts
- self.validators = validators or {}
+ self.checks = checks
@@
- self._validate_prompts()
- self._validate_validator_keys()
+ self._validate_checks()
@@
- plan = self._load_or_generate_plan(planner, schema_ctx)
+ plan = self._load_or_generate_plan(planner, schema_ctx)
@@
- return default_registry.build_llm_context(self.validators)
+ return ""
@@
- plan_hash = _compute_plan_hash(self.prompts, ...)
+ plan_hash = _compute_plan_hash(self.checks, ...)
@@
- plan = planner.generate_plan(self.prompts, schema_ctx)
+ plan = planner.generate_plan(self.checks, schema_ctx)
@@
- validator = self.validators.get(check.check_name)
- ...
- passed = bool(validator(value))
+ passed = self._apply_validation(value, check.validation)
@@
+ def _apply_validation(self, value: Any, spec) -> bool:
+ if spec.kind == "null_pct":
+ return float(value) <= float(spec.params["max_pct"])
+ if spec.kind == "row_count_min":
+ return int(value) >= int(spec.params["min_count"])
+ if spec.kind == "duplicate_pct":
+ return float(value) <= float(spec.params["max_pct"])
+ if spec.kind == "exact_value":
+ return value == spec.params["expected"]
+ if spec.kind == "numeric_between":
+ return float(spec.params["min_val"]) <= float(value) <=
float(spec.params["max_val"])
+ raise ValueError(f"Unsupported aggregate validation kind:
{spec.kind}")
+
+ def _validate_checks(self) -> None:
+ if not self.checks:
+ raise ValueError("checks must not be empty")
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_llm_data_quality.py
@@
- LLMDataQualityOperator(
- prompts={
- "null_order_id": "Check the percentage of rows where order_id
is NULL",
- "duplicate_orders": "Calculate the percentage of duplicate
order_id values",
- "min_row_count": "Count the total number of rows in the orders
table",
- },
- validators={
- "null_order_id": null_pct_check(max_pct=0.0),
- "duplicate_orders": duplicate_pct_check(max_pct=0.0),
- "min_row_count": row_count_check(min_count=10_000),
- },
- )
+ LLMDataQualityOperator(
+ checks=[
+ {"name": "null_order_id", "description": "order_id should never
be null"},
+ {"name": "duplicate_orders", "description": "duplicate order_id
percentage should be 0%"},
+ {"name": "min_row_count", "description": "orders table should
have at least 10,000 rows"},
+ ],
+ )
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]