kaxil commented on code in PR #61794: URL: https://github.com/apache/airflow/pull/61794#discussion_r2813246239
########## providers/common/ai/docs/operators.rst: ########## @@ -0,0 +1,62 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +LLM Operators +============= + +LLM Operators are specialized Airflow operators designed to generate sql queries using Large Language Models (LLMs) based on provided prompts. +These operators leverage the capabilities DBApi Hook in airflow to get the database schema and use it to generate SQL queries. see the currently supported databases to extract schema dynamically using +existing DBApi Hooks or optionally provide the database schema in the input datasource config. + +Current LLM Operators uses the Pydantic AI framework to connect to different LLM providers. + +Supported LLM Providers +----------------------- + +- OpenAI +- Google +- Anthropic +- GitHub Review Comment: GitHub feels an odd one here -- what's that for? ########## providers/common/ai/docs/operators.rst: ########## @@ -0,0 +1,62 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. 
The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +LLM Operators +============= + +LLM Operators are specialized Airflow operators designed to generate sql queries using Large Language Models (LLMs) based on provided prompts. Review Comment: This doc should be SQL agnostic -- i.e. the following may be better: ```suggestion LLM Operators are specialized Airflow operators designed to interact with LLMs in various different ways example to generate sql queries based on provided prompts. ``` ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py: ########## @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from airflow.providers.common.ai.operators.base_llm import BaseLLMOperator + +if TYPE_CHECKING: + from pydantic_ai.agent import AgentRunResult + + from airflow.providers.common.ai.configs.datasource import DataSourceConfig + + +class SQLQueryResponseOutputType(BaseModel): + """Output type LLM Sql query generate.""" + + sql_query_prompt_dict: dict[str, str] + + +class LLMSQLQueryOperator(BaseLLMOperator): + """Operator to generate SQL queries based on prompts for multiple datasources.""" + + def __init__( + self, datasource_configs: list[DataSourceConfig], provider_model: str | None = None, **kwargs + ): + super().__init__(datasource_configs=datasource_configs, **kwargs) + self.provider_model = provider_model + + def execute(self, context): + """Execute LLM Sql query operator.""" + return super().execute(context) Review Comment: Since we aren't doing anything in this method, we can nuke it and let inheritance take care of it ########## providers/common/ai/src/airflow/providers/common/ai/llm_providers/model_providers.py: ########## @@ -0,0 +1,212 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.ai.llm_providers.base import ModelProvider + +if TYPE_CHECKING: + from pydantic_ai import ModelSettings + from pydantic_ai.models import Model + + +class OpenAIModelProvider(ModelProvider): + """Model provider for OpenAI models.""" + + @property + def provider_name(self) -> str: + """Return the name of the provider.""" + return "openai" + + def get_model_settings(self, model_settings: dict[str, Any] | None = None) -> ModelSettings | None: + """Get model settings for OpenAI models.""" + from pydantic_ai.models.openai import OpenAIChatModelSettings + + if model_settings is None: + return None + + self.log.info("Model settings %s initialized for %s", model_settings, self.provider_name) + + return OpenAIChatModelSettings(**model_settings) + + def build_model(self, model_name: str, api_key: str, **kwargs) -> Model: + """Build and returns OpenAIChatModel.""" + from pydantic_ai.models.openai import OpenAIChatModel + from pydantic_ai.providers.openai import OpenAIProvider + + model_settings = self.get_model_settings(kwargs.get("model_settings")) + + return OpenAIChatModel(model_name, provider=OpenAIProvider(api_key=api_key), settings=model_settings) + + +class AnthropicModelProvider(ModelProvider): + """Model provider for Anthropic models.""" + + @property + def provider_name(self) -> str: + """Return the name of the provider.""" + return "anthropic" + + def get_model_settings(self, model_settings: dict[str, Any] | None = None) -> ModelSettings | None: + """Get model settings for Anthropic models.""" + from pydantic_ai.models.anthropic import AnthropicModelSettings + + if model_settings is None: + return None + + self.log.info("Model settings %s initialized for %s", model_settings, self.provider_name) + + return AnthropicModelSettings(**model_settings) + + def build_model(self, model_name: str, api_key: str, **kwargs) -> Model: + """Build and returns 
AnthropicModel.""" + from pydantic_ai.models.anthropic import AnthropicModel + from pydantic_ai.providers.anthropic import AnthropicProvider + + model_settings = self.get_model_settings(kwargs.get("model_settings")) + + model = AnthropicModel( + model_name, provider=AnthropicProvider(api_key=api_key), settings=model_settings + ) + self.log.info("Model %s initialized for provider %s", model_name, self.provider_name) + return model + + +class GoogleModelProvider(ModelProvider): + """Model provider for Google models.""" + + @property + def provider_name(self) -> str: + """Return the name of the provider.""" + return "google" + + def get_model_settings(self, model_settings: dict[str, Any] | None = None) -> ModelSettings | None: + """Get model settings for Google models.""" + from pydantic_ai.models.google import GoogleModelSettings + + if model_settings is None: + return None + + self.log.info("Model settings %s initialized for %s", model_settings, self.provider_name) + return GoogleModelSettings(**model_settings) + + def build_model(self, model_name: str, api_key: str, **kwargs) -> Model: + """Build and returns GoogleModel.""" + from pydantic_ai.models.google import GoogleModel + from pydantic_ai.providers.google import GoogleProvider + + model_settings = self.get_model_settings(kwargs.get("model_settings")) + model = GoogleModel(model_name, provider=GoogleProvider(api_key=api_key), settings=model_settings) + + self.log.info("Model %s initialized for provider %s", model_name, self.provider_name) + return model + + +def _build_open_ai_based_model(model_name, provider, **kwargs) -> Model: + """ + Create a model instance based on the provided model name and parameters. + + There are models that are compatible with OpenAI compatible modes, https://ai.pydantic.dev/models/openai/#openai-compatible-models + This function builds those models. 
+ """ + from pydantic_ai.models.openai import OpenAIChatModel, OpenAIChatModelSettings + + settings = kwargs.get("model_settings") + if settings: + settings = OpenAIChatModelSettings(**settings) + + return OpenAIChatModel(model_name, provider=provider, settings=settings) + + +class GithubModelProvider(ModelProvider): + """Model provider for GitHub models.""" + + @property + def provider_name(self) -> str: + """Return the name of the provider.""" + return "github" + + def build_model(self, model_name: str, api_key: str, **kwargs) -> Model: + """Build and returns GitHubModel.""" + from pydantic_ai.providers.github import GitHubProvider + + model = _build_open_ai_based_model(model_name, GitHubProvider(api_key=api_key), **kwargs) + + self.log.info("Model %s initialized for provider %s", model_name, self.provider_name) + return model + + +class ModelProviderFactory: + """Factory class for model providers.""" + + model_providers: dict[str, ModelProvider] = {} + _initialized: bool = False + + def __init__(self): + self._initialize_default_model_providers() + + @classmethod + def _initialize_default_model_providers(cls): + if cls._initialized: + return + # TODO Implement more default model providers https://ai.pydantic.dev/models/overview/ + defaults = [ + AnthropicModelProvider(), + GithubModelProvider(), + GoogleModelProvider(), + OpenAIModelProvider(), + ] + for provider in defaults: + cls.register_model_provider(provider) + cls._initialized = True + + @classmethod + def register_model_provider(cls, provider: ModelProvider) -> None: + """Register a model provider.""" + cls.model_providers[provider.provider_name] = provider + + def get_model_provider(self, provider_name: str) -> ModelProvider: + """ + Get the model provider for the given provider name. 
+ + eg: provider names are: openai or google or claude + """ + if provider_name and provider_name not in self.model_providers: + raise ValueError(f"Model provider {provider_name} is not registered.") + + return self.model_providers[provider_name] + + @staticmethod + def parse_model_provider_name(provider_model_name: str) -> str | tuple[str, str]: + """Return the provider name and model name from the model name.""" + if ":" not in provider_model_name: + raise ValueError( + f"Invalid model name {provider_model_name}. Model name must be in the format provider:model_name, e.g. github:openai/gpt-4o-mini" + ) Review Comment: This might not always work --- example using Anthropic model via bedrock ########## providers/common/ai/pyproject.toml: ########## @@ -59,13 +59,28 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=3.0.0", + "pydantic-ai-slim[anthropic,google,openai]>=1.58.0", Review Comment: Will help with dependency conflicts ########## providers/common/ai/src/airflow/providers/common/ai/evals/llm_sql.py: ########## @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one Review Comment: Should this module be better to just call it `providers/common/ai/src/airflow/providers/common/ai/evals/sql.py` ########## providers/common/ai/docs/operators.rst: ########## @@ -0,0 +1,62 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +LLM Operators +============= + +LLM Operators are specialized Airflow operators designed to generate sql queries using Large Language Models (LLMs) based on provided prompts. +These operators leverage the capabilities DBApi Hook in airflow to get the database schema and use it to generate SQL queries. see the currently supported databases to extract schema dynamically using +existing DBApi Hooks or optionally provide the database schema in the input datasource config. + +Current LLM Operators uses the Pydantic AI framework to connect to different LLM providers. + +Supported LLM Providers +----------------------- + +- OpenAI +- Google +- Anthropic +- GitHub + +* See for more configuration details ``https://ai.pydantic.dev/models/overview/`` + +Supported Databases +------------------- +Supported databases to extract schema dynamically using existing DBApi Hook. Review Comment: ```suggestion Supported databases to extract schema dynamically using existing ``DBApiHook``. ``` ########## providers/common/ai/docs/operators.rst: ########## @@ -0,0 +1,62 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. 
Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +LLM Operators +============= + +LLM Operators are specialized Airflow operators designed to generate sql queries using Large Language Models (LLMs) based on provided prompts. +These operators leverage the capabilities DBApi Hook in airflow to get the database schema and use it to generate SQL queries. see the currently supported databases to extract schema dynamically using +existing DBApi Hooks or optionally provide the database schema in the input datasource config. + +Current LLM Operators uses the Pydantic AI framework to connect to different LLM providers. + +Supported LLM Providers +----------------------- + +- OpenAI +- Google Review Comment: Gemini ########## providers/common/ai/pyproject.toml: ########## @@ -59,13 +59,28 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=3.0.0", + "pydantic-ai-slim[anthropic,google,openai]>=1.58.0", Review Comment: I don't think we should hardcode all providers, we should just depend on pydantic-ai-slim and keep everything else optional since someone might just use anthropic vs openai ########## providers/common/ai/docs/operators.rst: ########## @@ -0,0 +1,62 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. 
http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +LLM Operators +============= + +LLM Operators are specialized Airflow operators designed to generate sql queries using Large Language Models (LLMs) based on provided prompts. +These operators leverage the capabilities DBApi Hook in airflow to get the database schema and use it to generate SQL queries. see the currently supported databases to extract schema dynamically using Review Comment: and then in a subheading you can talk about SQL specific capability and `DbApiHook` ########## providers/common/ai/src/airflow/providers/common/ai/operators/base_llm.py: ########## @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import json +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from pydantic_ai.agent import Agent +from pydantic_evals import Dataset + +from airflow.providers.common.ai.evals.llm_sql import ValidateSQL, build_test_case +from airflow.providers.common.ai.exceptions import AgentResponseEvaluationFailure, PromptBuildError +from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook +from airflow.sdk import BaseOperator + +if TYPE_CHECKING: + from airflow.providers.common.ai.configs.datasource import DataSourceConfig + + +class BaseLLMOperator(BaseOperator): + """Base operator for LLM based tasks.""" + + BLOCKED_KEYWORDS = ["DROP", "TRUNCATE", "DELETE FROM", "ALTER TABLE", "GRANT", "REVOKE"] Review Comment: This feels like something that should only belong to the SQL-specific operator -- not the base class; keeping it here makes the base class unusable for non-SQL tasks. ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py: ########## @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from airflow.providers.common.ai.operators.base_llm import BaseLLMOperator + +if TYPE_CHECKING: + from pydantic_ai.agent import AgentRunResult + + from airflow.providers.common.ai.configs.datasource import DataSourceConfig + + +class SQLQueryResponseOutputType(BaseModel): + """Output type LLM Sql query generate.""" + + sql_query_prompt_dict: dict[str, str] + + +class LLMSQLQueryOperator(BaseLLMOperator): + """Operator to generate SQL queries based on prompts for multiple datasources.""" + + def __init__( + self, datasource_configs: list[DataSourceConfig], provider_model: str | None = None, **kwargs + ): + super().__init__(datasource_configs=datasource_configs, **kwargs) + self.provider_model = provider_model + + def execute(self, context): + """Execute LLM Sql query operator.""" + return super().execute(context) + + @property + def get_output_type(self): + """Output type for LLM Sql query generates.""" + return SQLQueryResponseOutputType + + @property + def get_instruction(self): + """Instruction for LLM Agent.""" + db_names = [] + for config in self.datasource_configs: + if config.db_name is None: + config.db_name = config.uri.split("://")[1] + db_names.append(config.db_name) + unique_db_names = set(db_names) + db_name_str = ", ".join(unique_db_names) + + if self.instruction is None: + self.instruction = ( + f"You are a SQL expert integrated with {db_name_str}, Your task is to generate SQL query's based on the prompts and" Review Comment: ```suggestion f"You are a SQL expert integrated with {db_name_str}, Your task is to generate SQL queries based on the prompts and" ``` ########## providers/common/ai/pyproject.toml: ########## @@ -59,13 +59,28 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=3.0.0", + 
"pydantic-ai-slim[anthropic,google,openai]>=1.58.0", + "pydantic-evals>=1.58.0", Review Comment: `pydantic-evals` is only used for post-generation SQL validation (evaluate_result), not core operator functionality. Making it a hard dependency adds weight and compounds the dependency resolution problem imo. ########## providers/postgres/src/airflow/providers/postgres/hooks/postgres.py: ########## @@ -691,3 +691,11 @@ def insert_rows( nb_rows += len(chunked_rows) self.log.info("Loaded %s rows into %s so far", nb_rows, table) self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table) + + def get_schema(self, table_name: str): + from airflow.providers.common.sql.hooks.handlers import fetch_all_handler + + return self.run( + sql=f"""SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';""", Review Comment: Yup need to fix this ########## providers/postgres/src/airflow/providers/postgres/hooks/postgres.py: ########## @@ -691,3 +691,11 @@ def insert_rows( nb_rows += len(chunked_rows) self.log.info("Loaded %s rows into %s so far", nb_rows, table) self.log.info("Done loading. Loaded a total of %s rows into %s", nb_rows, table) + + def get_schema(self, table_name: str): + from airflow.providers.common.sql.hooks.handlers import fetch_all_handler + + return self.run( + sql=f"""SELECT column_name, data_type FROM information_schema.columns WHERE table_name = '{table_name}';""", Review Comment: You can pass it as `parameters=parameters` too since `self.run` takes query params too ########## providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py: ########## @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.ai.exceptions import ModelCreationError +from airflow.providers.common.ai.llm_providers.model_providers import ModelProviderFactory +from airflow.sdk import BaseHook, Connection + +if TYPE_CHECKING: + from pydantic_ai.models import Model + + from airflow.providers.common.ai.llm_providers.base import ModelProvider + from airflow.providers.common.sql.hooks.sql import DbApiHook + + +class PydanticAIHook(BaseHook): + """Hook for Pydantic AI.""" + + _model_provider_factory: ModelProviderFactory | None = None + + conn_name_attr = "pydantic_ai_conn_id" + default_conn_name = "pydantic_ai_default" + conn_type = "pydantic_ai" + hook_name = "PydanticAI" + + def __init__( + self, pydantic_ai_conn_id: str = default_conn_name, provider_model: str | None = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.provider_model = provider_model + self.pydantic_ai_conn_id = pydantic_ai_conn_id + self._api_key: str | None = None + self.connection: Connection | None = None + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + return { + "hidden_fields": ["schema"], + "relabeling": { + "password": "API Key", + }, + "placeholders": { + "extra": json.dumps( + { + "provider_model": "", + "model_settings": {}, + } + ) + }, + } + + def 
get_conn(self) -> Connection: + if self.connection is None: + self.connection = self.get_connection(self.pydantic_ai_conn_id) + return self.connection + + def get_provider_model_name_from_conn(self): + return self.get_conn().extra_dejson.get("provider_model") + + @cached_property + def get_api_key_from_conn(self): + return self.get_conn().password Review Comment: Since this is a property field: `_api_key_from_conn` might be a better name than "get_*" ########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py: ########## @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from airflow.providers.common.ai.operators.base_llm import BaseLLMOperator + +if TYPE_CHECKING: + from pydantic_ai.agent import AgentRunResult + + from airflow.providers.common.ai.configs.datasource import DataSourceConfig + + +class SQLQueryResponseOutputType(BaseModel): + """Output type LLM Sql query generate.""" + + sql_query_prompt_dict: dict[str, str] + + +class LLMSQLQueryOperator(BaseLLMOperator): + """Operator to generate SQL queries based on prompts for multiple datasources.""" + + def __init__( + self, datasource_configs: list[DataSourceConfig], provider_model: str | None = None, **kwargs + ): + super().__init__(datasource_configs=datasource_configs, **kwargs) + self.provider_model = provider_model + + def execute(self, context): + """Execute LLM Sql query operator.""" + return super().execute(context) + + @property + def get_output_type(self): + """Output type for LLM Sql query generates.""" + return SQLQueryResponseOutputType + + @property + def get_instruction(self): + """Instruction for LLM Agent.""" + db_names = [] + for config in self.datasource_configs: + if config.db_name is None: + config.db_name = config.uri.split("://")[1] + db_names.append(config.db_name) + unique_db_names = set(db_names) + db_name_str = ", ".join(unique_db_names) + + if self.instruction is None: + self.instruction = ( + f"You are a SQL expert integrated with {db_name_str}, Your task is to generate SQL query's based on the prompts and" + f"return the each query and its prompt in key value pair dict format. Make sure the generated query supports given DatabaseType and It should not generate any query without these dangerous keywords: {self.BLOCKED_KEYWORDS} without where class" Review Comment: what does "without where class" mean here? 
########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py: ########## @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pydantic import BaseModel + +from airflow.providers.common.ai.operators.base_llm import BaseLLMOperator + +if TYPE_CHECKING: + from pydantic_ai.agent import AgentRunResult + + from airflow.providers.common.ai.configs.datasource import DataSourceConfig + + +class SQLQueryResponseOutputType(BaseModel): + """Output type LLM Sql query generate.""" + + sql_query_prompt_dict: dict[str, str] + + +class LLMSQLQueryOperator(BaseLLMOperator): + """Operator to generate SQL queries based on prompts for multiple datasources.""" + + def __init__( + self, datasource_configs: list[DataSourceConfig], provider_model: str | None = None, **kwargs + ): + super().__init__(datasource_configs=datasource_configs, **kwargs) + self.provider_model = provider_model + + def execute(self, context): + """Execute LLM Sql query operator.""" + return super().execute(context) + + @property + def get_output_type(self): + """Output type for LLM Sql query generates.""" + return SQLQueryResponseOutputType + + @property + def 
get_instruction(self): + """Instruction for LLM Agent.""" + db_names = [] + for config in self.datasource_configs: + if config.db_name is None: + config.db_name = config.uri.split("://")[1] Review Comment: should DB type come from hook dialect instead of uri.split("://")[1]? ########## providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py: ########## @@ -0,0 +1,114 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import json +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from airflow.providers.common.ai.exceptions import ModelCreationError +from airflow.providers.common.ai.llm_providers.model_providers import ModelProviderFactory +from airflow.sdk import BaseHook, Connection + +if TYPE_CHECKING: + from pydantic_ai.models import Model + + from airflow.providers.common.ai.llm_providers.base import ModelProvider + from airflow.providers.common.sql.hooks.sql import DbApiHook + + +class PydanticAIHook(BaseHook): + """Hook for Pydantic AI.""" + + _model_provider_factory: ModelProviderFactory | None = None + + conn_name_attr = "pydantic_ai_conn_id" + default_conn_name = "pydantic_ai_default" + conn_type = "pydantic_ai" + hook_name = "PydanticAI" + + def __init__( + self, pydantic_ai_conn_id: str = default_conn_name, provider_model: str | None = None, **kwargs + ) -> None: + super().__init__(**kwargs) + self.provider_model = provider_model + self.pydantic_ai_conn_id = pydantic_ai_conn_id + self._api_key: str | None = None + self.connection: Connection | None = None + + @classmethod + def get_ui_field_behaviour(cls) -> dict[str, Any]: + return { + "hidden_fields": ["schema"], + "relabeling": { + "password": "API Key", + }, + "placeholders": { + "extra": json.dumps( + { + "provider_model": "", + "model_settings": {}, + } + ) + }, + } + + def get_conn(self) -> Connection: + if self.connection is None: + self.connection = self.get_connection(self.pydantic_ai_conn_id) + return self.connection + + def get_provider_model_name_from_conn(self): + return self.get_conn().extra_dejson.get("provider_model") + + @cached_property + def get_api_key_from_conn(self): + return self.get_conn().password + + @classmethod + def get_provider_model_factory(cls): + if cls._model_provider_factory is None: + cls._model_provider_factory = ModelProviderFactory() + return cls._model_provider_factory + + @classmethod + def 
register_model_provider(cls, provider: ModelProvider): + cls.get_provider_model_factory().register_model_provider(provider) + + def get_model(self, **kwargs) -> Model: + try: + provider_model_name = self.provider_model or self.get_provider_model_name_from_conn() + if not provider_model_name: + raise ValueError("No provider model name provided") + provider_name, model_name = self.get_provider_model_factory().parse_model_provider_name( + provider_model_name + ) + + settings = self.get_conn().extra_dejson.get("model_settings") + if settings: + kwargs["model_settings"] = settings + return self._model_provider_factory.get_model_provider(provider_name).build_model( + model_name, api_key=self.get_api_key_from_conn, **kwargs + ) + except Exception as e: + raise ModelCreationError(f"Error building model: {e}") + + @staticmethod Review Comment: Not sure yet, maybe a Mixin might be better? ########## airflow-core/src/airflow/utils/db.py: ########## @@ -628,6 +628,12 @@ def get_default_connections(): + Connection( + conn_id="pydantic_ai_default", + conn_type="pydantic_ai", + password="password", + extra={"model_name": ""}, Review Comment: Should default connection extra use `provider_model` instead of `model_name`, since `PydanticAIHook` reads `provider_model`? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
