This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a1daaf1b60e Replace deprecated `pydantic-ai` MCP classes with
`MCPToolset` in `common.ai` (#69006)
a1daaf1b60e is described below
commit a1daaf1b60ea30764e2f02a07fd27aa5146595f6
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jun 27 15:57:10 2026 +0100
Replace deprecated `pydantic-ai` MCP classes with `MCPToolset` in
`common.ai` (#69006)
---
providers/common/ai/docs/toolsets.rst | 30 ++---
.../common/ai/example_dags/example_mcp.py | 11 +-
.../src/airflow/providers/common/ai/hooks/mcp.py | 32 ++---
.../common/ai/toolsets/langchain_bridge.py | 6 +-
.../airflow/providers/common/ai/toolsets/mcp.py | 13 +-
.../ai/tests/unit/common/ai/hooks/test_mcp.py | 146 ++++++++++++---------
6 files changed, 130 insertions(+), 108 deletions(-)
diff --git a/providers/common/ai/docs/toolsets.rst b/providers/common/ai/docs/toolsets.rst
index d284227c0ac..fd2927187c9 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -42,11 +42,10 @@ passed to any pydantic-ai ``Agent``, including via
.. note::
``AgentOperator`` accepts **any** ``AbstractToolset`` implementation — not
- just the Airflow-native toolsets above. PydanticAI's own MCP server
- classes (``MCPServerStreamableHTTP``, ``MCPServerSSE``, ``MCPServerStdio``)
- and third-party toolsets work too. The Airflow-native toolsets add
- connection management, secret backend integration, and the connection UI,
- but you are not locked in.
+ just the Airflow-native toolsets above. PydanticAI's own ``MCPToolset``
+ (built over a FastMCP transport) and third-party toolsets work too. The
+ Airflow-native toolsets add connection management, secret backend
+ integration, and the connection UI, but you are not locked in.
Using Toolsets Directly with PydanticAI
@@ -336,29 +335,30 @@ Using Multiple MCP Servers
],
)
-Direct PydanticAI MCP Servers
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Direct PydanticAI MCP Toolsets
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
-For prototyping or when you want full PydanticAI control, you can pass MCP
-server instances directly — no Airflow connection needed:
+For prototyping or when you want full PydanticAI control, you can pass
+``MCPToolset`` instances directly — no Airflow connection needed:
.. code-block:: python
- from pydantic_ai.mcp import MCPServerStreamableHTTP, MCPServerStdio
+ from fastmcp.client.transports import StdioTransport
+ from pydantic_ai.mcp import MCPToolset
AgentOperator(
task_id="direct_mcp",
prompt="What tools are available?",
llm_conn_id="pydanticai_default",
toolsets=[
- MCPServerStreamableHTTP("http://localhost:3001/mcp"),
- MCPServerStdio("uvx", args=["mcp-run-python"]),
+ MCPToolset("http://localhost:3001/mcp"),
+ MCPToolset(StdioTransport(command="uvx", args=["mcp-run-python"])),
],
)
-This works because PydanticAI's MCP server classes implement
-``AbstractToolset``. The tradeoff: URLs and credentials are hardcoded in DAG
-code instead of being managed through Airflow connections and secret backends.
+This works because PydanticAI's ``MCPToolset`` implements ``AbstractToolset``.
+The tradeoff: URLs and credentials are hardcoded in DAG code instead of being
+managed through Airflow connections and secret backends.
.. _agent-skills:
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
index 80233ea7c75..f2e9a423679 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
@@ -74,19 +74,20 @@ example_mcp_multiple_servers()
# ---------------------------------------------------------------------------
-# 3. Direct PydanticAI MCP servers (no Airflow connection needed)
+# 3. Direct PydanticAI MCP toolsets (no Airflow connection needed)
# ---------------------------------------------------------------------------
-# AgentOperator accepts any PydanticAI AbstractToolset, including MCP servers
+# AgentOperator accepts any PydanticAI AbstractToolset, including MCPToolset
# directly. Use this for prototyping or when you want full PydanticAI control.
#
-# from pydantic_ai.mcp import MCPServerStreamableHTTP, MCPServerStdio
+# from fastmcp.client.transports import StdioTransport
+# from pydantic_ai.mcp import MCPToolset
#
# AgentOperator(
# task_id="direct_mcp",
# prompt="What tools are available?",
# llm_conn_id="pydanticai_default",
# toolsets=[
-# MCPServerStreamableHTTP("http://localhost:3001/mcp"),
-# MCPServerStdio("uvx", args=["mcp-run-python"]),
+# MCPToolset("http://localhost:3001/mcp"),
+# MCPToolset(StdioTransport(command="uvx", args=["mcp-run-python"])),
# ],
# )
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py b/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
index f6749d280f3..b53ea712e8c 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
@@ -47,7 +47,7 @@ class MCPHook(BaseHook):
- **Extra.transport**: Transport type — ``http`` (default), ``sse``,
or ``stdio``
- **Extra.command**: Command to run for stdio transport (e.g. ``uvx``)
- **Extra.args**: Command arguments for stdio transport (e.g.
``["mcp-run-python"]``)
- - **Extra.timeout**: Connection timeout in seconds for stdio (default: 10)
+ - **Extra.timeout**: Connection init timeout in seconds for stdio (default: 10)
For HTTP/SSE transports the ``Authorization`` header is, by default, a static
``Bearer`` token taken from the connection ``password``. Endpoints that require
@@ -121,14 +121,18 @@ class MCPHook(BaseHook):
def get_conn(self) -> Any:
"""
- Return a configured PydanticAI MCP server instance.
+ Return a configured PydanticAI MCP toolset instance.
- Creates the appropriate MCP server based on the transport type
- in the connection's extra field:
+ Builds a :class:`~pydantic_ai.mcp.MCPToolset` over the FastMCP transport
+ matching the transport type in the connection's extra field:
- - ``http`` (default): :class:`~pydantic_ai.mcp.MCPServerStreamableHTTP`
- - ``sse``: :class:`~pydantic_ai.mcp.MCPServerSSE`
- - ``stdio``: :class:`~pydantic_ai.mcp.MCPServerStdio`
+ - ``http`` (default): ``fastmcp.client.transports.StreamableHttpTransport``
+ - ``sse``: ``fastmcp.client.transports.SSETransport``
+ - ``stdio``: ``fastmcp.client.transports.StdioTransport``
+
+ When ``tool_prefix`` is set the toolset is wrapped via
+ :meth:`~pydantic_ai.toolsets.abstract.AbstractToolset.prefixed`, so a
+ prefix of ``"weather"`` yields tool names like ``weather_get_forecast``.
The result is cached for the lifetime of this hook instance.
"""
@@ -136,7 +140,8 @@ class MCPHook(BaseHook):
return self._server
try:
- from pydantic_ai.mcp import MCPServerSSE, MCPServerStdio,
MCPServerStreamableHTTP
+ from fastmcp.client.transports import SSETransport,
StdioTransport, StreamableHttpTransport
+ from pydantic_ai.mcp import MCPToolset
except ImportError:
raise ImportError(
'MCP support requires the `mcp` package. Install it with: pip
install "pydantic-ai-slim[mcp]"'
@@ -149,15 +154,11 @@ class MCPHook(BaseHook):
if transport == "http":
if not conn.host:
raise ValueError(f"Connection {self.mcp_conn_id!r} requires a
host URL for HTTP transport.")
- self._server = MCPServerStreamableHTTP(
- conn.host, headers=self._auth_headers(conn),
tool_prefix=self.tool_prefix
- )
+ toolset = MCPToolset(StreamableHttpTransport(conn.host,
headers=self._auth_headers(conn)))
elif transport == "sse":
if not conn.host:
raise ValueError(f"Connection {self.mcp_conn_id!r} requires a
host URL for SSE transport.")
- self._server = MCPServerSSE(
- conn.host, headers=self._auth_headers(conn),
tool_prefix=self.tool_prefix
- )
+ toolset = MCPToolset(SSETransport(conn.host,
headers=self._auth_headers(conn)))
elif transport == "stdio":
command = extra.get("command")
if not command:
@@ -168,13 +169,14 @@ class MCPHook(BaseHook):
if isinstance(args, str):
args = [args]
timeout = extra.get("timeout", 10)
- self._server = MCPServerStdio(command, args=args, timeout=timeout,
tool_prefix=self.tool_prefix)
+ toolset = MCPToolset(StdioTransport(command=command, args=args),
init_timeout=timeout)
else:
raise ValueError(
f"Unknown transport {transport!r} in connection
{self.mcp_conn_id!r}. "
"Supported: 'http', 'sse', 'stdio'."
)
+ self._server = toolset.prefixed(self.tool_prefix) if self.tool_prefix
else toolset
return self._server
def test_connection(self) -> tuple[bool, str]:
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
index 3c2765c1a70..35876d82930 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
@@ -97,9 +97,9 @@ def airflow_toolset_to_langchain_tools(
and every ``call_tool`` each run under their own event loop, and pydantic-ai
opens and tears the connection down around each one. For ``MCPToolset`` this
means the server is reconnected on every tool call. That is fine for
- stateless tools (and for HTTP/SSE servers, modulo per-call latency), but an
- ``MCPServerStdio`` server, or any server that keeps state between calls,
- will lose that state because each call starts a fresh process/session.
+ stateless tools (and for HTTP/SSE servers, modulo per-call latency), but a
+ stdio server, or any server that keeps state between calls, will lose that
+ state because each call starts a fresh process/session.
.. note::
A pydantic-ai toolset is normally driven inside an agent run, where a
diff --git a/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py b/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
index d7aae3f9a6a..5a35cc3f54d 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
@@ -35,18 +35,17 @@ class MCPToolset(AbstractToolset[Any]):
Reads MCP server transport type, URL, command, and credentials from the
connection via :class:`~airflow.providers.common.ai.hooks.mcp.MCPHook` and
- creates the appropriate PydanticAI MCP server instance.
- All ``AbstractToolset`` methods delegate to the underlying MCP server.
+ builds the matching PydanticAI :class:`~pydantic_ai.mcp.MCPToolset`.
+ All ``AbstractToolset`` methods delegate to the underlying MCP toolset.
This is the recommended way to use MCP servers in Airflow — it stores
server configuration in Airflow connections (and secret backends) rather
than hard-coding URLs and credentials in DAG code.
- If you prefer full PydanticAI control, you can pass MCP server instances
- directly to ``AgentOperator(toolsets=[...])``, since
- :class:`~pydantic_ai.mcp.MCPServerStreamableHTTP`,
- :class:`~pydantic_ai.mcp.MCPServerSSE`, and
- :class:`~pydantic_ai.mcp.MCPServerStdio` all implement ``AbstractToolset``.
+ If you prefer full PydanticAI control, you can pass a
+ :class:`~pydantic_ai.mcp.MCPToolset` (built over a FastMCP transport)
+ directly to ``AgentOperator(toolsets=[...])``, since it implements
+ ``AbstractToolset``.
For MCP endpoints that need a freshly minted or short-lived token (e.g. a
Snowflake managed MCP server authenticated with a key-pair JWT, or OAuth /
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py b/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
index 2928376d699..6e7a5629c3f 100644
--- a/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
@@ -24,11 +24,13 @@ import pytest
from airflow.models.connection import Connection
from airflow.providers.common.ai.hooks.mcp import MCPHook
-# The hook imports MCP classes lazily inside get_conn(), so we must patch
-# them at their source in pydantic_ai.mcp rather than on the hook module.
-_MCP_HTTP = "pydantic_ai.mcp.MCPServerStreamableHTTP"
-_MCP_SSE = "pydantic_ai.mcp.MCPServerSSE"
-_MCP_STDIO = "pydantic_ai.mcp.MCPServerStdio"
+# The hook imports these lazily inside get_conn(), so we patch them at their
+# source modules rather than on the hook module. MCPToolset is the unified
+# pydantic-ai entrypoint; the three transports come from FastMCP.
+_MCP_TOOLSET = "pydantic_ai.mcp.MCPToolset"
+_HTTP_TRANSPORT = "fastmcp.client.transports.StreamableHttpTransport"
+_SSE_TRANSPORT = "fastmcp.client.transports.SSETransport"
+_STDIO_TRANSPORT = "fastmcp.client.transports.StdioTransport"
class TestMCPHookInit:
@@ -46,8 +48,9 @@ class TestMCPHookInit:
class TestMCPHookGetConn:
- @patch(_MCP_HTTP)
- def test_http_transport(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_http_transport(self, mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(
conn_id="test_conn",
@@ -57,11 +60,13 @@ class TestMCPHookGetConn:
with patch.object(hook, "get_connection", return_value=conn):
result = hook.get_conn()
- mock_server_cls.assert_called_once_with("http://localhost:3001/mcp",
headers=None, tool_prefix=None)
- assert result is mock_server_cls.return_value
+
mock_transport_cls.assert_called_once_with("http://localhost:3001/mcp",
headers=None)
+
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value)
+ assert result is mock_toolset_cls.return_value
- @patch(_MCP_HTTP)
- def test_http_is_default_transport(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_http_is_default_transport(self, mock_transport_cls,
mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(
conn_id="test_conn",
@@ -71,10 +76,12 @@ class TestMCPHookGetConn:
with patch.object(hook, "get_connection", return_value=conn):
hook.get_conn()
- mock_server_cls.assert_called_once_with("http://localhost:3001/mcp",
headers=None, tool_prefix=None)
+
mock_transport_cls.assert_called_once_with("http://localhost:3001/mcp",
headers=None)
+
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value)
- @patch(_MCP_HTTP)
- def test_http_with_auth_token(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_http_with_auth_token(self, mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(
conn_id="test_conn",
@@ -85,14 +92,14 @@ class TestMCPHookGetConn:
with patch.object(hook, "get_connection", return_value=conn):
hook.get_conn()
- mock_server_cls.assert_called_once_with(
+ mock_transport_cls.assert_called_once_with(
"http://localhost:3001/mcp",
headers={"Authorization": "Bearer my-secret-token"},
- tool_prefix=None,
)
- @patch(_MCP_HTTP)
- def test_passes_tool_prefix(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_passes_tool_prefix(self, mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn", tool_prefix="weather")
conn = Connection(
conn_id="test_conn",
@@ -100,14 +107,16 @@ class TestMCPHookGetConn:
host="http://localhost:3001/mcp",
)
with patch.object(hook, "get_connection", return_value=conn):
- hook.get_conn()
+ result = hook.get_conn()
- mock_server_cls.assert_called_once_with(
- "http://localhost:3001/mcp", headers=None, tool_prefix="weather"
- )
+ # tool_prefix is applied by wrapping the toolset, not via a
constructor arg.
+
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value)
+
mock_toolset_cls.return_value.prefixed.assert_called_once_with("weather")
+ assert result is mock_toolset_cls.return_value.prefixed.return_value
- @patch(_MCP_SSE)
- def test_sse_transport(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_SSE_TRANSPORT, autospec=True)
+ def test_sse_transport(self, mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(
conn_id="test_conn",
@@ -118,11 +127,13 @@ class TestMCPHookGetConn:
with patch.object(hook, "get_connection", return_value=conn):
result = hook.get_conn()
- mock_server_cls.assert_called_once_with("http://localhost:3001/sse",
headers=None, tool_prefix=None)
- assert result is mock_server_cls.return_value
+
mock_transport_cls.assert_called_once_with("http://localhost:3001/sse",
headers=None)
+
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value)
+ assert result is mock_toolset_cls.return_value
- @patch(_MCP_STDIO)
- def test_stdio_transport(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_STDIO_TRANSPORT, autospec=True)
+ def test_stdio_transport(self, mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(
conn_id="test_conn",
@@ -132,11 +143,13 @@ class TestMCPHookGetConn:
with patch.object(hook, "get_connection", return_value=conn):
result = hook.get_conn()
- mock_server_cls.assert_called_once_with("uvx",
args=["mcp-run-python"], timeout=10, tool_prefix=None)
- assert result is mock_server_cls.return_value
+ mock_transport_cls.assert_called_once_with(command="uvx",
args=["mcp-run-python"])
+
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value,
init_timeout=10)
+ assert result is mock_toolset_cls.return_value
- @patch(_MCP_STDIO)
- def test_stdio_custom_timeout(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_STDIO_TRANSPORT, autospec=True)
+ def test_stdio_custom_timeout(self, mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(
conn_id="test_conn",
@@ -148,10 +161,12 @@ class TestMCPHookGetConn:
with patch.object(hook, "get_connection", return_value=conn):
hook.get_conn()
- mock_server_cls.assert_called_once_with("python", args=["-m",
"server"], timeout=30, tool_prefix=None)
+ mock_transport_cls.assert_called_once_with(command="python",
args=["-m", "server"])
+
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value,
init_timeout=30)
- @patch(_MCP_STDIO)
- def test_args_string_converted_to_list(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_STDIO_TRANSPORT, autospec=True)
+ def test_args_string_converted_to_list(self, mock_transport_cls,
mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(
conn_id="test_conn",
@@ -161,7 +176,7 @@ class TestMCPHookGetConn:
with patch.object(hook, "get_connection", return_value=conn):
hook.get_conn()
- mock_server_cls.assert_called_once_with("uvx",
args=["mcp-run-python"], timeout=10, tool_prefix=None)
+ mock_transport_cls.assert_called_once_with(command="uvx",
args=["mcp-run-python"])
def test_http_without_host_raises(self):
hook = MCPHook(mcp_conn_id="test_conn")
@@ -203,8 +218,9 @@ class TestMCPHookGetConn:
with pytest.raises(ValueError, match="Unknown transport"):
hook.get_conn()
- @patch(_MCP_HTTP)
- def test_caches_server(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_caches_server(self, mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
with patch.object(hook, "get_connection", return_value=conn):
@@ -212,12 +228,13 @@ class TestMCPHookGetConn:
second = hook.get_conn()
assert first is second
- mock_server_cls.assert_called_once()
+ mock_toolset_cls.assert_called_once()
class TestMCPHookTestConnection:
- @patch(_MCP_HTTP)
- def test_successful_config(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_successful_config(self, mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn")
conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
with patch.object(hook, "get_connection", return_value=conn):
@@ -249,21 +266,22 @@ class TestMCPHookUIFieldBehaviour:
class TestMCPHookTokenProvider:
- @patch(_MCP_HTTP)
- def test_http_uses_token_provider(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_http_uses_token_provider(self, mock_transport_cls,
mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda:
"minted-jwt")
conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
with patch.object(hook, "get_connection", return_value=conn):
hook.get_conn()
- mock_server_cls.assert_called_once_with(
+ mock_transport_cls.assert_called_once_with(
"http://localhost:3001/mcp",
headers={"Authorization": "Bearer minted-jwt"},
- tool_prefix=None,
)
- @patch(_MCP_HTTP)
- def test_token_provider_overrides_static_password(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_token_provider_overrides_static_password(self,
mock_transport_cls, mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: "fresh")
conn = Connection(
conn_id="test_conn",
@@ -274,14 +292,14 @@ class TestMCPHookTokenProvider:
with patch.object(hook, "get_connection", return_value=conn):
hook.get_conn()
- mock_server_cls.assert_called_once_with(
+ mock_transport_cls.assert_called_once_with(
"http://localhost:3001/mcp",
headers={"Authorization": "Bearer fresh"},
- tool_prefix=None,
)
- @patch(_MCP_HTTP)
- def test_token_provider_called_when_establishing_connection(self,
mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_token_provider_called_when_establishing_connection(self,
mock_transport_cls, mock_toolset_cls):
provider = MagicMock(return_value="tok")
hook = MCPHook(mcp_conn_id="test_conn", token_provider=provider)
conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
@@ -290,22 +308,22 @@ class TestMCPHookTokenProvider:
provider.assert_called_once_with()
- @patch(_MCP_HTTP)
- def test_masks_minted_token(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_HTTP_TRANSPORT, autospec=True)
+ def test_masks_minted_token(self, mock_transport_cls, mock_toolset_cls):
"""The minted token must be registered with secret masking, like
conn.password."""
hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda:
"minted-jwt")
conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
with (
patch.object(hook, "get_connection", return_value=conn),
- patch("airflow.providers.common.ai.hooks.mcp.mask_secret") as
mock_mask,
+ patch("airflow.providers.common.ai.hooks.mcp.mask_secret",
autospec=True) as mock_mask,
):
hook.get_conn()
mock_mask.assert_called_once_with("minted-jwt")
@pytest.mark.parametrize("bad_token", ["", None], ids=["empty",
"non_string"])
- @patch(_MCP_HTTP)
- def test_invalid_token_raises(self, mock_server_cls, bad_token):
+ def test_invalid_token_raises(self, bad_token):
"""A token_provider returning a non-string or empty value fails
loud."""
hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda:
bad_token)
conn = Connection(conn_id="test_conn", conn_type="mcp",
host="http://localhost:3001/mcp")
@@ -313,8 +331,9 @@ class TestMCPHookTokenProvider:
with pytest.raises(ValueError, match="must return a non-empty
string token"):
hook.get_conn()
- @patch(_MCP_SSE)
- def test_sse_uses_token_provider(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_SSE_TRANSPORT, autospec=True)
+ def test_sse_uses_token_provider(self, mock_transport_cls,
mock_toolset_cls):
hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda:
"minted")
conn = Connection(
conn_id="test_conn",
@@ -325,14 +344,14 @@ class TestMCPHookTokenProvider:
with patch.object(hook, "get_connection", return_value=conn):
hook.get_conn()
- mock_server_cls.assert_called_once_with(
+ mock_transport_cls.assert_called_once_with(
"http://localhost:3001/sse",
headers={"Authorization": "Bearer minted"},
- tool_prefix=None,
)
- @patch(_MCP_STDIO)
- def test_stdio_does_not_invoke_token_provider(self, mock_server_cls):
+ @patch(_MCP_TOOLSET, autospec=True)
+ @patch(_STDIO_TRANSPORT, autospec=True)
+ def test_stdio_does_not_invoke_token_provider(self, mock_transport_cls,
mock_toolset_cls):
"""stdio has no HTTP headers, so the token provider must not be
called."""
provider = MagicMock(return_value="tok")
hook = MCPHook(mcp_conn_id="test_conn", token_provider=provider)
@@ -345,4 +364,5 @@ class TestMCPHookTokenProvider:
hook.get_conn()
provider.assert_not_called()
- mock_server_cls.assert_called_once_with("uvx",
args=["mcp-run-python"], timeout=10, tool_prefix=None)
+ mock_transport_cls.assert_called_once_with(command="uvx",
args=["mcp-run-python"])
+
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value,
init_timeout=10)