This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a1daaf1b60e Replace deprecated `pydantic-ai` MCP classes with `MCPToolset` in `common.ai` (#69006)
a1daaf1b60e is described below

commit a1daaf1b60ea30764e2f02a07fd27aa5146595f6
Author: Kaxil Naik <[email protected]>
AuthorDate: Sat Jun 27 15:57:10 2026 +0100

    Replace deprecated `pydantic-ai` MCP classes with `MCPToolset` in `common.ai` (#69006)
---
 providers/common/ai/docs/toolsets.rst              |  30 ++---
 .../common/ai/example_dags/example_mcp.py          |  11 +-
 .../src/airflow/providers/common/ai/hooks/mcp.py   |  32 ++---
 .../common/ai/toolsets/langchain_bridge.py         |   6 +-
 .../airflow/providers/common/ai/toolsets/mcp.py    |  13 +-
 .../ai/tests/unit/common/ai/hooks/test_mcp.py      | 146 ++++++++++++---------
 6 files changed, 130 insertions(+), 108 deletions(-)

diff --git a/providers/common/ai/docs/toolsets.rst 
b/providers/common/ai/docs/toolsets.rst
index d284227c0ac..fd2927187c9 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -42,11 +42,10 @@ passed to any pydantic-ai ``Agent``, including via
 .. note::
 
     ``AgentOperator`` accepts **any** ``AbstractToolset`` implementation — not
-    just the Airflow-native toolsets above. PydanticAI's own MCP server
-    classes (``MCPServerStreamableHTTP``, ``MCPServerSSE``, ``MCPServerStdio``)
-    and third-party toolsets work too. The Airflow-native toolsets add
-    connection management, secret backend integration, and the connection UI,
-    but you are not locked in.
+    just the Airflow-native toolsets above. PydanticAI's own ``MCPToolset``
+    (built over a FastMCP transport) and third-party toolsets work too. The
+    Airflow-native toolsets add connection management, secret backend
+    integration, and the connection UI, but you are not locked in.
 
 
 Using Toolsets Directly with PydanticAI
@@ -336,29 +335,30 @@ Using Multiple MCP Servers
         ],
     )
 
-Direct PydanticAI MCP Servers
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Direct PydanticAI MCP Toolsets
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-For prototyping or when you want full PydanticAI control, you can pass MCP
-server instances directly — no Airflow connection needed:
+For prototyping or when you want full PydanticAI control, you can pass
+``MCPToolset`` instances directly — no Airflow connection needed:
 
 .. code-block:: python
 
-    from pydantic_ai.mcp import MCPServerStreamableHTTP, MCPServerStdio
+    from fastmcp.client.transports import StdioTransport
+    from pydantic_ai.mcp import MCPToolset
 
     AgentOperator(
         task_id="direct_mcp",
         prompt="What tools are available?",
         llm_conn_id="pydanticai_default",
         toolsets=[
-            MCPServerStreamableHTTP("http://localhost:3001/mcp"),
-            MCPServerStdio("uvx", args=["mcp-run-python"]),
+            MCPToolset("http://localhost:3001/mcp"),
+            MCPToolset(StdioTransport(command="uvx", args=["mcp-run-python"])),
         ],
     )
 
-This works because PydanticAI's MCP server classes implement
-``AbstractToolset``. The tradeoff: URLs and credentials are hardcoded in DAG
-code instead of being managed through Airflow connections and secret backends.
+This works because PydanticAI's ``MCPToolset`` implements ``AbstractToolset``.
+The tradeoff: URLs and credentials are hardcoded in DAG code instead of being
+managed through Airflow connections and secret backends.
 
 
 .. _agent-skills:
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
index 80233ea7c75..f2e9a423679 100644
--- 
a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
+++ 
b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_mcp.py
@@ -74,19 +74,20 @@ example_mcp_multiple_servers()
 
 
 # ---------------------------------------------------------------------------
-# 3. Direct PydanticAI MCP servers (no Airflow connection needed)
+# 3. Direct PydanticAI MCP toolsets (no Airflow connection needed)
 # ---------------------------------------------------------------------------
-# AgentOperator accepts any PydanticAI AbstractToolset, including MCP servers
+# AgentOperator accepts any PydanticAI AbstractToolset, including MCPToolset
 # directly. Use this for prototyping or when you want full PydanticAI control.
 #
-#   from pydantic_ai.mcp import MCPServerStreamableHTTP, MCPServerStdio
+#   from fastmcp.client.transports import StdioTransport
+#   from pydantic_ai.mcp import MCPToolset
 #
 #   AgentOperator(
 #       task_id="direct_mcp",
 #       prompt="What tools are available?",
 #       llm_conn_id="pydanticai_default",
 #       toolsets=[
-#           MCPServerStreamableHTTP("http://localhost:3001/mcp"),
-#           MCPServerStdio("uvx", args=["mcp-run-python"]),
+#           MCPToolset("http://localhost:3001/mcp"),
+#           MCPToolset(StdioTransport(command="uvx", args=["mcp-run-python"])),
 #       ],
 #   )
diff --git a/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py 
b/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
index f6749d280f3..b53ea712e8c 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/hooks/mcp.py
@@ -47,7 +47,7 @@ class MCPHook(BaseHook):
         - **Extra.transport**: Transport type — ``http`` (default), ``sse``, 
or ``stdio``
         - **Extra.command**: Command to run for stdio transport (e.g. ``uvx``)
         - **Extra.args**: Command arguments for stdio transport (e.g. 
``["mcp-run-python"]``)
-        - **Extra.timeout**: Connection timeout in seconds for stdio (default: 
10)
+        - **Extra.timeout**: Connection init timeout in seconds for stdio 
(default: 10)
 
     For HTTP/SSE transports the ``Authorization`` header is, by default, a 
static
     ``Bearer`` token taken from the connection ``password``. Endpoints that 
require
@@ -121,14 +121,18 @@ class MCPHook(BaseHook):
 
     def get_conn(self) -> Any:
         """
-        Return a configured PydanticAI MCP server instance.
+        Return a configured PydanticAI MCP toolset instance.
 
-        Creates the appropriate MCP server based on the transport type
-        in the connection's extra field:
+        Builds a :class:`~pydantic_ai.mcp.MCPToolset` over the FastMCP 
transport
+        matching the transport type in the connection's extra field:
 
-        - ``http`` (default): :class:`~pydantic_ai.mcp.MCPServerStreamableHTTP`
-        - ``sse``: :class:`~pydantic_ai.mcp.MCPServerSSE`
-        - ``stdio``: :class:`~pydantic_ai.mcp.MCPServerStdio`
+        - ``http`` (default): 
``fastmcp.client.transports.StreamableHttpTransport``
+        - ``sse``: ``fastmcp.client.transports.SSETransport``
+        - ``stdio``: ``fastmcp.client.transports.StdioTransport``
+
+        When ``tool_prefix`` is set the toolset is wrapped via
+        :meth:`~pydantic_ai.toolsets.abstract.AbstractToolset.prefixed`, so a
+        prefix of ``"weather"`` yields tool names like 
``weather_get_forecast``.
 
         The result is cached for the lifetime of this hook instance.
         """
@@ -136,7 +140,8 @@ class MCPHook(BaseHook):
             return self._server
 
         try:
-            from pydantic_ai.mcp import MCPServerSSE, MCPServerStdio, 
MCPServerStreamableHTTP
+            from fastmcp.client.transports import SSETransport, 
StdioTransport, StreamableHttpTransport
+            from pydantic_ai.mcp import MCPToolset
         except ImportError:
             raise ImportError(
                 'MCP support requires the `mcp` package. Install it with: pip 
install "pydantic-ai-slim[mcp]"'
@@ -149,15 +154,11 @@ class MCPHook(BaseHook):
         if transport == "http":
             if not conn.host:
                 raise ValueError(f"Connection {self.mcp_conn_id!r} requires a 
host URL for HTTP transport.")
-            self._server = MCPServerStreamableHTTP(
-                conn.host, headers=self._auth_headers(conn), 
tool_prefix=self.tool_prefix
-            )
+            toolset = MCPToolset(StreamableHttpTransport(conn.host, 
headers=self._auth_headers(conn)))
         elif transport == "sse":
             if not conn.host:
                 raise ValueError(f"Connection {self.mcp_conn_id!r} requires a 
host URL for SSE transport.")
-            self._server = MCPServerSSE(
-                conn.host, headers=self._auth_headers(conn), 
tool_prefix=self.tool_prefix
-            )
+            toolset = MCPToolset(SSETransport(conn.host, 
headers=self._auth_headers(conn)))
         elif transport == "stdio":
             command = extra.get("command")
             if not command:
@@ -168,13 +169,14 @@ class MCPHook(BaseHook):
             if isinstance(args, str):
                 args = [args]
             timeout = extra.get("timeout", 10)
-            self._server = MCPServerStdio(command, args=args, timeout=timeout, 
tool_prefix=self.tool_prefix)
+            toolset = MCPToolset(StdioTransport(command=command, args=args), 
init_timeout=timeout)
         else:
             raise ValueError(
                 f"Unknown transport {transport!r} in connection 
{self.mcp_conn_id!r}. "
                 "Supported: 'http', 'sse', 'stdio'."
             )
 
+        self._server = toolset.prefixed(self.tool_prefix) if self.tool_prefix 
else toolset
         return self._server
 
     def test_connection(self) -> tuple[bool, str]:
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
 
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
index 3c2765c1a70..35876d82930 100644
--- 
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
+++ 
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/langchain_bridge.py
@@ -97,9 +97,9 @@ def airflow_toolset_to_langchain_tools(
         and every ``call_tool`` each run under their own event loop, and 
pydantic-ai
         opens and tears the connection down around each one. For 
``MCPToolset`` this
         means the server is reconnected on every tool call. That is fine for
-        stateless tools (and for HTTP/SSE servers, modulo per-call latency), 
but an
-        ``MCPServerStdio`` server, or any server that keeps state between 
calls,
-        will lose that state because each call starts a fresh process/session.
+        stateless tools (and for HTTP/SSE servers, modulo per-call latency), 
but a
+        stdio server, or any server that keeps state between calls, will lose 
that
+        state because each call starts a fresh process/session.
 
     .. note::
         A pydantic-ai toolset is normally driven inside an agent run, where a
diff --git 
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py 
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
index d7aae3f9a6a..5a35cc3f54d 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/mcp.py
@@ -35,18 +35,17 @@ class MCPToolset(AbstractToolset[Any]):
 
     Reads MCP server transport type, URL, command, and credentials from the
     connection via :class:`~airflow.providers.common.ai.hooks.mcp.MCPHook` and
-    creates the appropriate PydanticAI MCP server instance.
-    All ``AbstractToolset`` methods delegate to the underlying MCP server.
+    builds the matching PydanticAI :class:`~pydantic_ai.mcp.MCPToolset`.
+    All ``AbstractToolset`` methods delegate to the underlying MCP toolset.
 
     This is the recommended way to use MCP servers in Airflow — it stores
     server configuration in Airflow connections (and secret backends) rather
     than hard-coding URLs and credentials in DAG code.
 
-    If you prefer full PydanticAI control, you can pass MCP server instances
-    directly to ``AgentOperator(toolsets=[...])``, since
-    :class:`~pydantic_ai.mcp.MCPServerStreamableHTTP`,
-    :class:`~pydantic_ai.mcp.MCPServerSSE`, and
-    :class:`~pydantic_ai.mcp.MCPServerStdio` all implement ``AbstractToolset``.
+    If you prefer full PydanticAI control, you can pass a
+    :class:`~pydantic_ai.mcp.MCPToolset` (built over a FastMCP transport)
+    directly to ``AgentOperator(toolsets=[...])``, since it implements
+    ``AbstractToolset``.
 
     For MCP endpoints that need a freshly minted or short-lived token (e.g. a
     Snowflake managed MCP server authenticated with a key-pair JWT, or OAuth /
diff --git a/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py 
b/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
index 2928376d699..6e7a5629c3f 100644
--- a/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
+++ b/providers/common/ai/tests/unit/common/ai/hooks/test_mcp.py
@@ -24,11 +24,13 @@ import pytest
 from airflow.models.connection import Connection
 from airflow.providers.common.ai.hooks.mcp import MCPHook
 
-# The hook imports MCP classes lazily inside get_conn(), so we must patch
-# them at their source in pydantic_ai.mcp rather than on the hook module.
-_MCP_HTTP = "pydantic_ai.mcp.MCPServerStreamableHTTP"
-_MCP_SSE = "pydantic_ai.mcp.MCPServerSSE"
-_MCP_STDIO = "pydantic_ai.mcp.MCPServerStdio"
+# The hook imports these lazily inside get_conn(), so we patch them at their
+# source modules rather than on the hook module. MCPToolset is the unified
+# pydantic-ai entrypoint; the three transports come from FastMCP.
+_MCP_TOOLSET = "pydantic_ai.mcp.MCPToolset"
+_HTTP_TRANSPORT = "fastmcp.client.transports.StreamableHttpTransport"
+_SSE_TRANSPORT = "fastmcp.client.transports.SSETransport"
+_STDIO_TRANSPORT = "fastmcp.client.transports.StdioTransport"
 
 
 class TestMCPHookInit:
@@ -46,8 +48,9 @@ class TestMCPHookInit:
 
 
 class TestMCPHookGetConn:
-    @patch(_MCP_HTTP)
-    def test_http_transport(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_http_transport(self, mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(
             conn_id="test_conn",
@@ -57,11 +60,13 @@ class TestMCPHookGetConn:
         with patch.object(hook, "get_connection", return_value=conn):
             result = hook.get_conn()
 
-        mock_server_cls.assert_called_once_with("http://localhost:3001/mcp";, 
headers=None, tool_prefix=None)
-        assert result is mock_server_cls.return_value
+        
mock_transport_cls.assert_called_once_with("http://localhost:3001/mcp";, 
headers=None)
+        
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value)
+        assert result is mock_toolset_cls.return_value
 
-    @patch(_MCP_HTTP)
-    def test_http_is_default_transport(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_http_is_default_transport(self, mock_transport_cls, 
mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(
             conn_id="test_conn",
@@ -71,10 +76,12 @@ class TestMCPHookGetConn:
         with patch.object(hook, "get_connection", return_value=conn):
             hook.get_conn()
 
-        mock_server_cls.assert_called_once_with("http://localhost:3001/mcp";, 
headers=None, tool_prefix=None)
+        
mock_transport_cls.assert_called_once_with("http://localhost:3001/mcp";, 
headers=None)
+        
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value)
 
-    @patch(_MCP_HTTP)
-    def test_http_with_auth_token(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_http_with_auth_token(self, mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(
             conn_id="test_conn",
@@ -85,14 +92,14 @@ class TestMCPHookGetConn:
         with patch.object(hook, "get_connection", return_value=conn):
             hook.get_conn()
 
-        mock_server_cls.assert_called_once_with(
+        mock_transport_cls.assert_called_once_with(
             "http://localhost:3001/mcp";,
             headers={"Authorization": "Bearer my-secret-token"},
-            tool_prefix=None,
         )
 
-    @patch(_MCP_HTTP)
-    def test_passes_tool_prefix(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_passes_tool_prefix(self, mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn", tool_prefix="weather")
         conn = Connection(
             conn_id="test_conn",
@@ -100,14 +107,16 @@ class TestMCPHookGetConn:
             host="http://localhost:3001/mcp";,
         )
         with patch.object(hook, "get_connection", return_value=conn):
-            hook.get_conn()
+            result = hook.get_conn()
 
-        mock_server_cls.assert_called_once_with(
-            "http://localhost:3001/mcp";, headers=None, tool_prefix="weather"
-        )
+        # tool_prefix is applied by wrapping the toolset, not via a 
constructor arg.
+        
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value)
+        
mock_toolset_cls.return_value.prefixed.assert_called_once_with("weather")
+        assert result is mock_toolset_cls.return_value.prefixed.return_value
 
-    @patch(_MCP_SSE)
-    def test_sse_transport(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_SSE_TRANSPORT, autospec=True)
+    def test_sse_transport(self, mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(
             conn_id="test_conn",
@@ -118,11 +127,13 @@ class TestMCPHookGetConn:
         with patch.object(hook, "get_connection", return_value=conn):
             result = hook.get_conn()
 
-        mock_server_cls.assert_called_once_with("http://localhost:3001/sse";, 
headers=None, tool_prefix=None)
-        assert result is mock_server_cls.return_value
+        
mock_transport_cls.assert_called_once_with("http://localhost:3001/sse";, 
headers=None)
+        
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value)
+        assert result is mock_toolset_cls.return_value
 
-    @patch(_MCP_STDIO)
-    def test_stdio_transport(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_STDIO_TRANSPORT, autospec=True)
+    def test_stdio_transport(self, mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(
             conn_id="test_conn",
@@ -132,11 +143,13 @@ class TestMCPHookGetConn:
         with patch.object(hook, "get_connection", return_value=conn):
             result = hook.get_conn()
 
-        mock_server_cls.assert_called_once_with("uvx", 
args=["mcp-run-python"], timeout=10, tool_prefix=None)
-        assert result is mock_server_cls.return_value
+        mock_transport_cls.assert_called_once_with(command="uvx", 
args=["mcp-run-python"])
+        
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value, 
init_timeout=10)
+        assert result is mock_toolset_cls.return_value
 
-    @patch(_MCP_STDIO)
-    def test_stdio_custom_timeout(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_STDIO_TRANSPORT, autospec=True)
+    def test_stdio_custom_timeout(self, mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(
             conn_id="test_conn",
@@ -148,10 +161,12 @@ class TestMCPHookGetConn:
         with patch.object(hook, "get_connection", return_value=conn):
             hook.get_conn()
 
-        mock_server_cls.assert_called_once_with("python", args=["-m", 
"server"], timeout=30, tool_prefix=None)
+        mock_transport_cls.assert_called_once_with(command="python", 
args=["-m", "server"])
+        
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value, 
init_timeout=30)
 
-    @patch(_MCP_STDIO)
-    def test_args_string_converted_to_list(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_STDIO_TRANSPORT, autospec=True)
+    def test_args_string_converted_to_list(self, mock_transport_cls, 
mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(
             conn_id="test_conn",
@@ -161,7 +176,7 @@ class TestMCPHookGetConn:
         with patch.object(hook, "get_connection", return_value=conn):
             hook.get_conn()
 
-        mock_server_cls.assert_called_once_with("uvx", 
args=["mcp-run-python"], timeout=10, tool_prefix=None)
+        mock_transport_cls.assert_called_once_with(command="uvx", 
args=["mcp-run-python"])
 
     def test_http_without_host_raises(self):
         hook = MCPHook(mcp_conn_id="test_conn")
@@ -203,8 +218,9 @@ class TestMCPHookGetConn:
             with pytest.raises(ValueError, match="Unknown transport"):
                 hook.get_conn()
 
-    @patch(_MCP_HTTP)
-    def test_caches_server(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_caches_server(self, mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(conn_id="test_conn", conn_type="mcp", 
host="http://localhost:3001/mcp";)
         with patch.object(hook, "get_connection", return_value=conn):
@@ -212,12 +228,13 @@ class TestMCPHookGetConn:
             second = hook.get_conn()
 
         assert first is second
-        mock_server_cls.assert_called_once()
+        mock_toolset_cls.assert_called_once()
 
 
 class TestMCPHookTestConnection:
-    @patch(_MCP_HTTP)
-    def test_successful_config(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_successful_config(self, mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn")
         conn = Connection(conn_id="test_conn", conn_type="mcp", 
host="http://localhost:3001/mcp";)
         with patch.object(hook, "get_connection", return_value=conn):
@@ -249,21 +266,22 @@ class TestMCPHookUIFieldBehaviour:
 
 
 class TestMCPHookTokenProvider:
-    @patch(_MCP_HTTP)
-    def test_http_uses_token_provider(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_http_uses_token_provider(self, mock_transport_cls, 
mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: 
"minted-jwt")
         conn = Connection(conn_id="test_conn", conn_type="mcp", 
host="http://localhost:3001/mcp";)
         with patch.object(hook, "get_connection", return_value=conn):
             hook.get_conn()
 
-        mock_server_cls.assert_called_once_with(
+        mock_transport_cls.assert_called_once_with(
             "http://localhost:3001/mcp";,
             headers={"Authorization": "Bearer minted-jwt"},
-            tool_prefix=None,
         )
 
-    @patch(_MCP_HTTP)
-    def test_token_provider_overrides_static_password(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_token_provider_overrides_static_password(self, 
mock_transport_cls, mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: "fresh")
         conn = Connection(
             conn_id="test_conn",
@@ -274,14 +292,14 @@ class TestMCPHookTokenProvider:
         with patch.object(hook, "get_connection", return_value=conn):
             hook.get_conn()
 
-        mock_server_cls.assert_called_once_with(
+        mock_transport_cls.assert_called_once_with(
             "http://localhost:3001/mcp";,
             headers={"Authorization": "Bearer fresh"},
-            tool_prefix=None,
         )
 
-    @patch(_MCP_HTTP)
-    def test_token_provider_called_when_establishing_connection(self, 
mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_token_provider_called_when_establishing_connection(self, 
mock_transport_cls, mock_toolset_cls):
         provider = MagicMock(return_value="tok")
         hook = MCPHook(mcp_conn_id="test_conn", token_provider=provider)
         conn = Connection(conn_id="test_conn", conn_type="mcp", 
host="http://localhost:3001/mcp";)
@@ -290,22 +308,22 @@ class TestMCPHookTokenProvider:
 
         provider.assert_called_once_with()
 
-    @patch(_MCP_HTTP)
-    def test_masks_minted_token(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_HTTP_TRANSPORT, autospec=True)
+    def test_masks_minted_token(self, mock_transport_cls, mock_toolset_cls):
         """The minted token must be registered with secret masking, like 
conn.password."""
         hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: 
"minted-jwt")
         conn = Connection(conn_id="test_conn", conn_type="mcp", 
host="http://localhost:3001/mcp";)
         with (
             patch.object(hook, "get_connection", return_value=conn),
-            patch("airflow.providers.common.ai.hooks.mcp.mask_secret") as 
mock_mask,
+            patch("airflow.providers.common.ai.hooks.mcp.mask_secret", 
autospec=True) as mock_mask,
         ):
             hook.get_conn()
 
         mock_mask.assert_called_once_with("minted-jwt")
 
     @pytest.mark.parametrize("bad_token", ["", None], ids=["empty", 
"non_string"])
-    @patch(_MCP_HTTP)
-    def test_invalid_token_raises(self, mock_server_cls, bad_token):
+    def test_invalid_token_raises(self, bad_token):
         """A token_provider returning a non-string or empty value fails 
loud."""
         hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: 
bad_token)
         conn = Connection(conn_id="test_conn", conn_type="mcp", 
host="http://localhost:3001/mcp";)
@@ -313,8 +331,9 @@ class TestMCPHookTokenProvider:
             with pytest.raises(ValueError, match="must return a non-empty 
string token"):
                 hook.get_conn()
 
-    @patch(_MCP_SSE)
-    def test_sse_uses_token_provider(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_SSE_TRANSPORT, autospec=True)
+    def test_sse_uses_token_provider(self, mock_transport_cls, 
mock_toolset_cls):
         hook = MCPHook(mcp_conn_id="test_conn", token_provider=lambda: 
"minted")
         conn = Connection(
             conn_id="test_conn",
@@ -325,14 +344,14 @@ class TestMCPHookTokenProvider:
         with patch.object(hook, "get_connection", return_value=conn):
             hook.get_conn()
 
-        mock_server_cls.assert_called_once_with(
+        mock_transport_cls.assert_called_once_with(
             "http://localhost:3001/sse";,
             headers={"Authorization": "Bearer minted"},
-            tool_prefix=None,
         )
 
-    @patch(_MCP_STDIO)
-    def test_stdio_does_not_invoke_token_provider(self, mock_server_cls):
+    @patch(_MCP_TOOLSET, autospec=True)
+    @patch(_STDIO_TRANSPORT, autospec=True)
+    def test_stdio_does_not_invoke_token_provider(self, mock_transport_cls, 
mock_toolset_cls):
         """stdio has no HTTP headers, so the token provider must not be 
called."""
         provider = MagicMock(return_value="tok")
         hook = MCPHook(mcp_conn_id="test_conn", token_provider=provider)
@@ -345,4 +364,5 @@ class TestMCPHookTokenProvider:
             hook.get_conn()
 
         provider.assert_not_called()
-        mock_server_cls.assert_called_once_with("uvx", 
args=["mcp-run-python"], timeout=10, tool_prefix=None)
+        mock_transport_cls.assert_called_once_with(command="uvx", 
args=["mcp-run-python"])
+        
mock_toolset_cls.assert_called_once_with(mock_transport_cls.return_value, 
init_timeout=10)

Reply via email to