This is an automated email from the ASF dual-hosted git repository.
vikramkoka pushed a commit to branch aip99-langchain
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/aip99-langchain by this push:
new a34521eab7a Add LangChain hook to common.ai provider
a34521eab7a is described below
commit a34521eab7ab6554aaa179372bc320ad4de1f0c9
Author: Vikram Koka <[email protected]>
AuthorDate: Tue May 19 16:40:56 2026 +0100
Add LangChain hook to common.ai provider
- Adds LangChainHook to bridge Airflow connections to LangChain model
  constructors (ChatOpenAI, OpenAIEmbeddings), using constructor injection
  for credentials
- Reuses the existing pydanticai connection type so users configure one
  connection for PydanticAI, LlamaIndex, and LangChain
- Follows the same pattern as LlamaIndexHook: _resolve_connection_kwargs()
  extracts api_key and base_url from the Airflow connection and passes them
  directly to LangChain constructors
- Adds langchain optional dependency extra (langchain>=1.0.0,
  langchain-openai>=0.3.0)
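
As a rough illustration of the connection-resolution step described above,
the sketch below maps connection fields to constructor kwargs. It is
self-contained and does not import Airflow: FakeConnection is a hypothetical
stand-in for an Airflow connection, and resolve_connection_kwargs is a
free-function rendition written for this example, not the hook's own method.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class FakeConnection:
    """Hypothetical stand-in for an Airflow connection (illustration only)."""

    password: Optional[str] = None
    host: Optional[str] = None


def resolve_connection_kwargs(conn: FakeConnection) -> dict[str, Any]:
    """Map connection fields to constructor kwargs, skipping empty fields."""
    kwargs: dict[str, Any] = {}
    if conn.password:
        # The connection password carries the API key.
        kwargs["api_key"] = conn.password
    if conn.host:
        # The connection host carries an optional custom base URL.
        kwargs["base_url"] = conn.host
    return kwargs


only_key = resolve_connection_kwargs(FakeConnection(password="sk-test"))
empty = resolve_connection_kwargs(FakeConnection())
```

Empty fields are omitted rather than passed as None, so the model
constructor's own defaults apply when the connection leaves them unset.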
What's included
- hooks/langchain.py: LangChainHook(BaseHook) with get_chat_model() and
  get_embedding_model()
- tests/unit/common/ai/hooks/test_langchain.py: unit tests covering init,
  connection resolution, chat model, and embedding model
- docs/hooks/langchain.rst: hook documentation with usage examples
- provider.yaml: LangChain integration and hook registration
- pyproject.toml: langchain optional dependency extra
Design decisions
- BaseHook, not BaseAIHook: BaseAIHook is still in development. The hook
  will migrate in a follow-up PR once it ships.
- Constructor injection: credentials are passed as api_key=/base_url=
  kwargs to LangChain constructors, with no environment variable mutation.
  Matches the LlamaIndexHook pattern.
- Shared connection type: reuses the pydanticai connection type rather than
  introducing a new one. One connection works across all three frameworks.
- No @task.langchain yet: consistent with LlamaIndex (no @task.llamaindex).
  Deferred to the BaseAIHook migration PR.
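
To make the constructor-injection decision concrete, here is a minimal
sketch under stated assumptions. FakeChatModel and build_chat_model are
hypothetical names invented for this example; FakeChatModel merely mimics a
model class that accepts api_key and base_url keyword arguments. The point
is that credentials travel as arguments and the process environment is left
alone.

```python
import os
from typing import Any, Optional


class FakeChatModel:
    """Hypothetical stand-in for a chat model class (illustration only)."""

    def __init__(
        self,
        model: str,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
    ) -> None:
        self.model = model
        self.api_key = api_key
        self.base_url = base_url


def build_chat_model(model: str, conn_kwargs: dict[str, Any]) -> FakeChatModel:
    """Pass credentials as keyword arguments; os.environ is never written."""
    return FakeChatModel(model=model, **conn_kwargs)


# Snapshot the environment to show that construction does not mutate it.
env_before = dict(os.environ)
llm = build_chat_model("gpt-4o", {"api_key": "sk-test"})
env_unchanged = dict(os.environ) == env_before
```

Because nothing is written to the environment, two models built from
different connections in the same process cannot overwrite each other's
credentials.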
---
providers/common/ai/docs/hooks/langchain.rst | 111 ++++++++++++
providers/common/ai/provider.yaml | 6 +
providers/common/ai/pyproject.toml | 4 +
.../airflow/providers/common/ai/hooks/langchain.py | 97 +++++++++++
.../tests/unit/common/ai/hooks/test_langchain.py | 189 +++++++++++++++++++++
5 files changed, 407 insertions(+)
diff --git a/providers/common/ai/docs/hooks/langchain.rst b/providers/common/ai/docs/hooks/langchain.rst
new file mode 100644
index 00000000000..758c8b36b81
--- /dev/null
+++ b/providers/common/ai/docs/hooks/langchain.rst
@@ -0,0 +1,111 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements. See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership. The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License. You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied. See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+.. _howto/hook:langchain:
+
+``LangChainHook``
+=================
+
+Use :class:`~airflow.providers.common.ai.hooks.langchain.LangChainHook`
+to bridge Airflow connections to LangChain model constructors. The hook
+extracts credentials from an Airflow connection and returns configured
+LangChain model objects (``ChatOpenAI``, ``OpenAIEmbeddings``).
+
+The hook reuses the ``pydanticai`` connection type, so users configure a
+single connection for PydanticAI operators, LlamaIndex operators, and
+LangChain tasks.
+
+.. seealso::
+    :ref:`Connection configuration <howto/connection:pydanticai>`
+
+Basic Usage
+-----------
+
+Use the hook in a ``@task`` function to get a configured chat model:
+
+.. code-block:: python
+
+    from airflow.providers.common.ai.hooks.langchain import LangChainHook
+
+    @task
+    def run_chain(query: str) -> str:
+        hook = LangChainHook(llm_conn_id="pydanticai_default", llm_model="gpt-4o")
+        llm = hook.get_chat_model()
+
+        from langchain_core.prompts import ChatPromptTemplate
+        from langchain_core.output_parsers import StrOutputParser
+
+        prompt = ChatPromptTemplate.from_template("Summarize: {query}")
+        chain = prompt | llm | StrOutputParser()
+        return chain.invoke({"query": query})
+
+Embedding Models
+----------------
+
+Use :meth:`~LangChainHook.get_embedding_model` for embeddings. A
+separate ``embed_conn_id`` can be used when embedding and chat models
+use different API keys:
+
+.. code-block:: python
+
+    hook = LangChainHook(
+        llm_conn_id="chat_conn",
+        embed_conn_id="embed_conn",
+        embed_model="text-embedding-3-large",
+        llm_model="gpt-4o",
+    )
+    embeddings = hook.get_embedding_model()
+    chat_model = hook.get_chat_model()
+
+Connection Configuration
+------------------------
+
+The hook reads credentials from the Airflow connection:
+
+- **password** -- API key (passed as ``api_key`` to model constructors)
+- **host** -- Base URL (passed as ``base_url``; optional, for custom
+  endpoints or Ollama)
+
+Parameters
+----------
+
+.. list-table::
+   :header-rows: 1
+   :widths: 25 15 60
+
+   * - Parameter
+     - Default
+     - Description
+   * - ``llm_conn_id``
+     - ``pydanticai_default``
+     - Airflow connection ID for the LLM provider.
+   * - ``embed_conn_id``
+     - ``None`` (falls back to ``llm_conn_id``)
+     - Separate connection for embeddings.
+   * - ``embed_model``
+     - ``text-embedding-3-small``
+     - Embedding model name.
+   * - ``llm_model``
+     - ``None``
+     - Chat model name. Required for ``get_chat_model()``.
+
+Dependencies
+------------
+
+Install the ``langchain`` extra to use this hook::
+
+    pip install apache-airflow-providers-common-ai[langchain]
diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml
index 2a13392ea99..cc716d54a47 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -48,6 +48,9 @@ integrations:
   - integration-name: MCP Server
     external-doc-url: https://modelcontextprotocol.io/
     tags: [ai]
+  - integration-name: LangChain
+    external-doc-url: https://python.langchain.com/
+    tags: [ai]

 hooks:
   - integration-name: Pydantic AI
@@ -56,6 +59,9 @@ hooks:
   - integration-name: MCP Server
     python-modules:
       - airflow.providers.common.ai.hooks.mcp
+  - integration-name: LangChain
+    python-modules:
+      - airflow.providers.common.ai.hooks.langchain

 plugins:
   - name: hitl_review
diff --git a/providers/common/ai/pyproject.toml b/providers/common/ai/pyproject.toml
index 57ba93f7461..f44fe09fbd4 100644
--- a/providers/common/ai/pyproject.toml
+++ b/providers/common/ai/pyproject.toml
@@ -95,6 +95,10 @@ dependencies = [
 "common.sql" = [
     "apache-airflow-providers-common-sql"
 ]
+"langchain" = [
+    "langchain>=1.0.0",
+    "langchain-openai>=0.3.0",
+]

 [dependency-groups]
 dev = [
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/langchain.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/langchain.py
new file mode 100644
index 00000000000..a92db1228f8
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/langchain.py
@@ -0,0 +1,97 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Hook for LangChain integration with Airflow connections."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+if TYPE_CHECKING:
+    from langchain_core.embeddings import Embeddings
+    from langchain_core.language_models.chat_models import BaseChatModel
+
+
+class LangChainHook(BaseHook):
+    """
+    Bridge Airflow connections to LangChain model constructors.
+
+    Reuses the ``pydanticai`` connection type so users configure a single
+    connection for both pydantic-ai operators and LangChain tasks.
+
+    :param llm_conn_id: Airflow connection ID for the LLM/embedding provider.
+    :param embed_conn_id: Separate connection for embeddings. Defaults to
+        ``llm_conn_id`` when not provided.
+    :param embed_model: Embedding model name (e.g. ``text-embedding-3-small``).
+    :param llm_model: Chat model name (e.g. ``gpt-4o``). Only needed when
+        using :meth:`get_chat_model`.
+    """
+
+    conn_name_attr = "llm_conn_id"
+    default_conn_name = "pydanticai_default"
+    conn_type = "pydanticai"
+    hook_name = "LangChain"
+
+    def __init__(
+        self,
+        llm_conn_id: str = "pydanticai_default",
+        embed_conn_id: str | None = None,
+        embed_model: str = "text-embedding-3-small",
+        llm_model: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.llm_conn_id = llm_conn_id
+        self.embed_conn_id = embed_conn_id or llm_conn_id
+        self.embed_model = embed_model
+        self.llm_model = llm_model
+
+    def _resolve_connection_kwargs(self, conn_id: str) -> dict[str, Any]:
+        """Extract API key and base URL from an Airflow connection."""
+        conn = self.get_connection(conn_id)
+        kwargs: dict[str, Any] = {}
+        if conn.password:
+            kwargs["api_key"] = conn.password
+        if conn.host:
+            kwargs["base_url"] = conn.host
+        return kwargs
+
+    def get_chat_model(self) -> BaseChatModel:
+        """
+        Return a LangChain chat model configured from the Airflow connection.
+
+        Requires ``llm_model`` to be set on the hook.
+        """
+        if not self.llm_model:
+            raise ValueError("llm_model must be set to use get_chat_model()")
+
+        from langchain_openai import ChatOpenAI
+
+        conn_kwargs = self._resolve_connection_kwargs(self.llm_conn_id)
+        return ChatOpenAI(model=self.llm_model, **conn_kwargs)
+
+    def get_embedding_model(self) -> Embeddings:
+        """
+        Return a LangChain embedding model configured from the Airflow connection.
+
+        Uses ``embed_conn_id`` (falls back to ``llm_conn_id``) for credentials.
+        """
+        from langchain_openai import OpenAIEmbeddings
+
+        conn_kwargs = self._resolve_connection_kwargs(self.embed_conn_id)
+        return OpenAIEmbeddings(model=self.embed_model, **conn_kwargs)
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_langchain.py b/providers/common/ai/tests/unit/common/ai/hooks/test_langchain.py
new file mode 100644
index 00000000000..8ebd9b3f1b5
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_langchain.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.hooks.langchain import LangChainHook
+
+
+class TestLangChainHookInit:
+    def test_default_params(self):
+        hook = LangChainHook()
+        assert hook.llm_conn_id == "pydanticai_default"
+        assert hook.embed_conn_id == "pydanticai_default"
+        assert hook.embed_model == "text-embedding-3-small"
+        assert hook.llm_model is None
+
+    def test_separate_embed_conn_id(self):
+        hook = LangChainHook(llm_conn_id="llm_conn", embed_conn_id="embed_conn")
+        assert hook.llm_conn_id == "llm_conn"
+        assert hook.embed_conn_id == "embed_conn"
+
+    def test_embed_conn_defaults_to_llm_conn(self):
+        hook = LangChainHook(llm_conn_id="my_conn")
+        assert hook.embed_conn_id == "my_conn"
+
+    def test_conn_type_is_pydanticai(self):
+        assert LangChainHook.conn_type == "pydanticai"
+        assert LangChainHook.default_conn_name == "pydanticai_default"
+
+
+class TestResolveConnectionKwargs:
+    @patch.object(LangChainHook, "get_connection")
+    def test_extracts_password_as_api_key(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-test-key"
+        mock_conn.host = ""
+        mock_get_conn.return_value = mock_conn
+
+        hook = LangChainHook()
+        result = hook._resolve_connection_kwargs("test_conn")
+
+        assert result == {"api_key": "sk-test-key"}
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_extracts_host_as_base_url(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = ""
+        mock_conn.host = "https://custom.api.com"
+        mock_get_conn.return_value = mock_conn
+
+        hook = LangChainHook()
+        result = hook._resolve_connection_kwargs("test_conn")
+
+        assert result == {"base_url": "https://custom.api.com"}
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_both_password_and_host(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-key"
+        mock_conn.host = "https://api.example.com"
+        mock_get_conn.return_value = mock_conn
+
+        hook = LangChainHook()
+        result = hook._resolve_connection_kwargs("test_conn")
+
+        assert result == {"api_key": "sk-key", "base_url": "https://api.example.com"}
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_empty_fields_return_empty_dict(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = ""
+        mock_conn.host = ""
+        mock_get_conn.return_value = mock_conn
+
+        hook = LangChainHook()
+        result = hook._resolve_connection_kwargs("test_conn")
+
+        assert result == {}
+
+
+def _make_mock_chat_openai_module():
+    mock_module = MagicMock()
+    mock_cls = MagicMock()
+    mock_module.ChatOpenAI = mock_cls
+    return mock_module, mock_cls
+
+
+def _make_mock_openai_embeddings_module():
+    mock_module = MagicMock()
+    mock_cls = MagicMock()
+    mock_module.OpenAIEmbeddings = mock_cls
+    return mock_module, mock_cls
+
+
+class TestGetChatModel:
+    def test_raises_without_llm_model(self):
+        hook = LangChainHook()
+        with pytest.raises(ValueError, match="llm_model must be set"):
+            hook.get_chat_model()
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_returns_chat_openai(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-test"
+        mock_conn.host = ""
+        mock_get_conn.return_value = mock_conn
+
+        mock_module, mock_cls = _make_mock_chat_openai_module()
+
+        hook = LangChainHook(llm_model="gpt-4o")
+        with patch.dict("sys.modules", {"langchain_openai": mock_module}):
+            result = hook.get_chat_model()
+
+        mock_cls.assert_called_once_with(model="gpt-4o", api_key="sk-test")
+        assert result == mock_cls.return_value
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_passes_base_url(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-test"
+        mock_conn.host = "https://custom.api.com"
+        mock_get_conn.return_value = mock_conn
+
+        mock_module, mock_cls = _make_mock_chat_openai_module()
+
+        hook = LangChainHook(llm_model="gpt-4o")
+        with patch.dict("sys.modules", {"langchain_openai": mock_module}):
+            hook.get_chat_model()
+
+        mock_cls.assert_called_once_with(
+            model="gpt-4o", api_key="sk-test", base_url="https://custom.api.com"
+        )
+
+
+class TestGetEmbeddingModel:
+    @patch.object(LangChainHook, "get_connection")
+    def test_returns_openai_embeddings(self, mock_get_conn):
+        mock_conn = MagicMock()
+        mock_conn.password = "sk-test"
+        mock_conn.host = ""
+        mock_get_conn.return_value = mock_conn
+
+        mock_module, mock_cls = _make_mock_openai_embeddings_module()
+
+        hook = LangChainHook(embed_model="text-embedding-3-large")
+        with patch.dict("sys.modules", {"langchain_openai": mock_module}):
+            result = hook.get_embedding_model()
+
+        mock_cls.assert_called_once_with(model="text-embedding-3-large", api_key="sk-test")
+        assert result == mock_cls.return_value
+
+    @patch.object(LangChainHook, "get_connection")
+    def test_uses_embed_conn_id(self, mock_get_conn):
+        mock_conn_llm = MagicMock()
+        mock_conn_llm.password = "sk-llm"
+        mock_conn_llm.host = ""
+
+        mock_conn_embed = MagicMock()
+        mock_conn_embed.password = "sk-embed"
+        mock_conn_embed.host = ""
+
+        mock_get_conn.side_effect = lambda conn_id: (
+            mock_conn_embed if conn_id == "embed_conn" else mock_conn_llm
+        )
+
+        mock_module, mock_cls = _make_mock_openai_embeddings_module()
+
+        hook = LangChainHook(llm_conn_id="llm_conn", embed_conn_id="embed_conn")
+        with patch.dict("sys.modules", {"langchain_openai": mock_module}):
+            hook.get_embedding_model()
+
+        mock_cls.assert_called_once_with(model="text-embedding-3-small", api_key="sk-embed")