This is an automated email from the ASF dual-hosted git repository.
kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new db26df7d165 fix: Verify durable cached agent steps match the request
before replay (#68372)
db26df7d165 is described below
commit db26df7d1650f20454ffaa19d16be33fef2f8b13
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Jun 18 01:05:43 2026 +0100
fix: Verify durable cached agent steps match the request before replay
(#68372)
With durable=True, model responses and tool results were cached under
purely positional keys, so a retry replayed cached steps even when the
agent had changed between attempts (a prompt tweak, a model upgrade, a
toolset change, or a deploy landing between retries). The retry silently
continued a different conversation and logged nothing above DEBUG.
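Each cache entry now stores a fingerprint of the request that produced
it:
- LLM steps: model identity, message history minus per-attempt fields,
  settings, and the full ModelRequestParameters.
- Tool steps: tool name, args, and tool_call_id.

On a cache hit, the stored fingerprint is compared with the current
request's fingerprint before anything is replayed. A mismatch logs a
warning and re-runs the step live. A divergence also invalidates
downstream tool steps, because a fresh model response mints new
tool_call_ids.

Entries written by older provider versions have no fingerprint and
re-run instead of replaying. The one exception is a current request that
cannot be fingerprinted (not JSON-serializable). Both fingerprints are
then None and compare equal, so the step falls back to unverified
positional replay.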
---
providers/common/ai/docs/operators/agent.rst | 21 +-
.../providers/common/ai/durable/caching_model.py | 50 ++++-
.../providers/common/ai/durable/caching_toolset.py | 33 +++-
.../providers/common/ai/durable/fingerprint.py | 162 ++++++++++++++++
.../providers/common/ai/durable/step_counter.py | 4 +-
.../airflow/providers/common/ai/durable/storage.py | 77 ++++++--
.../airflow/providers/common/ai/operators/agent.py | 6 +-
.../unit/common/ai/durable/test_caching_model.py | 94 ++++++++-
.../unit/common/ai/durable/test_caching_toolset.py | 78 ++++++--
.../unit/common/ai/durable/test_fingerprint.py | 212 +++++++++++++++++++++
.../common/ai/durable/test_replay_verification.py | 138 ++++++++++++++
.../tests/unit/common/ai/durable/test_storage.py | 135 +++++++++++--
12 files changed, 933 insertions(+), 77 deletions(-)
diff --git a/providers/common/ai/docs/operators/agent.rst
b/providers/common/ai/docs/operators/agent.rst
index bf79c044e78..a79e5110048 100644
--- a/providers/common/ai/docs/operators/agent.rst
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -209,11 +209,28 @@ cache:
**How it works**
1. On first execution, each LLM response and tool result is saved to a JSON
- file as the agent progresses.
+ file as the agent progresses, together with a fingerprint of the request
+ that produced it (model, message history, settings, and tools for LLM
+ steps; tool name, arguments, and call id for tool steps).
2. If the task fails and Airflow retries it, completed steps are loaded from
the cache and returned without calling the model or tool. Steps not yet in
the cache proceed normally.
-3. After successful completion, the cache file is deleted.
+3. Before a step is replayed, its stored fingerprint is compared against the
+ current request. If anything changed between attempts -- the system
+ prompt, the model, the toolset, model settings, or the conversation so
+ far -- the stale entry is discarded, a warning is logged, and the step
+ re-runs live. A divergence also invalidates the steps after it: re-running
+ an LLM step produces fresh tool call ids, so tool results recorded under
+ the old conversation no longer match. A changed agent costs a re-run; it
+ never replays responses that belong to a different conversation.
+4. After successful completion, the cache file is deleted.
+
+Replay verification compares the **requests** sent to models and tools, not
+the code behind them. Editing a tool's implementation between attempts does
+not invalidate an already-cached result for an identical call, and pointing
+``llm_conn_id`` at a different endpoint serving the same model name does not
+invalidate cached responses -- delete the cache file to force a fully fresh
+run.
After the run, a single INFO summary line reports how many steps were
replayed vs executed fresh. Per-step detail is available at DEBUG level.
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py
b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py
index 0b2f85ecb40..18f89439d82 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_model.py
@@ -24,6 +24,8 @@ from typing import TYPE_CHECKING, Any
import structlog
from pydantic_ai.models.wrapper import WrapperModel
+from airflow.providers.common.ai.durable.fingerprint import
fingerprint_model_request
+
log = structlog.get_logger(logger_name="task")
if TYPE_CHECKING:
@@ -41,8 +43,12 @@ class CachingModel(WrapperModel):
Wraps a model to cache responses in ObjectStorage for durable execution.
On each ``request()`` call, checks if a cached response exists for the
- current step index. If so, returns the cached response without calling
- the underlying model. Otherwise, calls the model and caches the response.
+ current step index and was produced by an equivalent request (same model,
+ message history, settings, and tools -- compared via fingerprint). If so,
+ returns the cached response without calling the underlying model.
+ Otherwise, calls the model and caches the response. A fingerprint
+ mismatch means the agent changed between attempts; the stale entry is
+ discarded and the step re-runs live.
"""
storage: DurableStorage = field(repr=False)
@@ -67,15 +73,45 @@ class CachingModel(WrapperModel):
) -> ModelResponse:
step = self.counter.next_step()
key = f"model_step_{step}"
+ # Fingerprint the *prepared* request, not the raw arguments. Concrete
+ # models call ``prepare_request()`` at the start of ``request()`` to
merge
+ # their model-level ``settings`` and apply profile-specific transforms
+ # (thinking resolution, native-tool handling, output-mode defaults)
before
+ # the provider sees the request. Fingerprinting the raw arguments would
+ # miss a change that lives only at the model level -- e.g. a different
+ # temperature or thinking setting on the connection -- and replay a
stale
+ # response. The raw arguments are still passed to
``wrapped.request()``,
+ # which re-runs ``prepare_request()`` itself (it is pure and
idempotent).
+ prepared_settings, prepared_parameters = self.wrapped.prepare_request(
+ model_settings, model_request_parameters
+ )
+ fingerprint = fingerprint_model_request(
+ f"{self.wrapped.system}:{self.wrapped.model_name}",
+ messages,
+ prepared_settings,
+ prepared_parameters,
+ )
- cached = self.storage.load_model_response(key)
+ cached, cached_fingerprint = self.storage.load_model_response(key)
if cached is not None:
- self.counter.replayed_model += 1
- log.debug("Durable: replayed cached model response", step=step)
- return cached
+ if cached_fingerprint == fingerprint:
+ self.counter.replayed_model += 1
+ log.debug("Durable: replayed cached model response", step=step)
+ return cached
+ log.warning(
+ "Durable: cached model response does not match the current
request; "
+ "re-running this step instead of replaying",
+ step=step,
+ reason=(
+ "entry predates fingerprinting or the request could not be
fingerprinted"
+ if fingerprint is None or cached_fingerprint is None
+ else "model, prompt, message history, settings, or tools
changed since "
+ "the previous attempt"
+ ),
+ )
response = await self.wrapped.request(messages, model_settings,
model_request_parameters)
- self.storage.save_model_response(key, response)
+ self.storage.save_model_response(key, response,
fingerprint=fingerprint)
self.counter.cached_model += 1
log.debug("Durable: cached model response", step=step)
return response
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py
b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py
index 2fd58fe78a4..045c98aea1f 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/durable/caching_toolset.py
@@ -24,6 +24,8 @@ from typing import TYPE_CHECKING, Any
import structlog
from pydantic_ai.toolsets.wrapper import WrapperToolset
+from airflow.providers.common.ai.durable.fingerprint import
fingerprint_tool_call
+
if TYPE_CHECKING:
from pydantic_ai.toolsets.abstract import ToolsetTool
@@ -39,8 +41,12 @@ class CachingToolset(WrapperToolset[Any]):
Wraps a toolset to cache tool call results in ObjectStorage for durable
execution.
On each ``call_tool()`` invocation, checks if a cached result exists for
- the current step index. If so, returns the cached result without executing
- the tool. Otherwise, executes the tool and caches the result.
+ the current step index and was produced by the same call (same tool name,
+ arguments, and model-issued ``tool_call_id`` -- compared via fingerprint).
+ If so, returns the cached result without executing the tool. Otherwise,
+ executes the tool and caches the result. A fingerprint mismatch means the
+ conversation diverged from the previous attempt; the stale entry is
+ discarded and the tool runs live.
The step index is grabbed before the first ``await``, so parallel tool
calls via ``asyncio.gather`` get deterministic indices (tasks start
@@ -61,15 +67,28 @@ class CachingToolset(WrapperToolset[Any]):
# even when multiple tool calls run concurrently via asyncio.gather.
step = self.counter.next_step()
key = f"tool_step_{step}"
+ fingerprint = fingerprint_tool_call(name, tool_args, ctx.tool_call_id)
- found, cached = self.storage.load_tool_result(key)
+ found, cached, cached_fingerprint = self.storage.load_tool_result(key)
if found:
- self.counter.replayed_tool += 1
- log.debug("Durable: replayed cached tool result", step=step,
tool=name)
- return cached
+ if cached_fingerprint == fingerprint:
+ self.counter.replayed_tool += 1
+ log.debug("Durable: replayed cached tool result", step=step,
tool=name)
+ return cached
+ log.warning(
+ "Durable: cached tool result does not match the current tool
call; "
+ "re-running the tool instead of replaying",
+ step=step,
+ tool=name,
+ reason=(
+ "entry predates fingerprinting or the call could not be
fingerprinted"
+ if fingerprint is None or cached_fingerprint is None
+ else "the conversation diverged from the previous attempt"
+ ),
+ )
result = await self.wrapped.call_tool(name, tool_args, ctx, tool)
- self.storage.save_tool_result(key, result)
+ self.storage.save_tool_result(key, result, fingerprint=fingerprint)
self.counter.cached_tool += 1
log.debug("Durable: cached tool result", step=step, tool=name)
return result
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/durable/fingerprint.py
b/providers/common/ai/src/airflow/providers/common/ai/durable/fingerprint.py
new file mode 100644
index 00000000000..6b5a2d490a5
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/durable/fingerprint.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Request fingerprints for durable replay verification.
+
+Durable caching keys steps positionally (``model_step_{N}`` /
``tool_step_{N}``).
+Position alone cannot tell whether a cached entry still corresponds to the
+current request: if the prompt, model, toolset, or message history changed
+between the failed attempt and the retry, replaying by position would feed the
+agent responses recorded for a different conversation.
+
+Each cache entry therefore stores a fingerprint of the request that produced
+it. On a cache hit the stored fingerprint is compared against the current
+request; a mismatch is treated as a cache miss and the step re-runs live.
+A divergence invalidates downstream steps too: a fresh model response carries
+newly generated ``tool_call_id`` values, which are part of the tool
+fingerprint, so stale tool results recorded under the old conversation no
+longer match.
+
+Fields that pydantic-ai regenerates on every attempt (message-level
+``timestamp``/``run_id``/``conversation_id`` and part-level ``timestamp``)
+are excluded from the fingerprint. Requests that cannot be serialized to
+JSON fingerprint as ``None``, which degrades that step to unverified
+positional replay (the pre-fingerprint behavior) rather than disabling
+caching.
+"""
+
+from __future__ import annotations
+
+import hashlib
+import json
+from typing import TYPE_CHECKING, Any
+
+import structlog
+from pydantic import TypeAdapter
+from pydantic_ai.messages import ModelMessagesTypeAdapter
+from pydantic_ai.models import ModelRequestParameters
+
+if TYPE_CHECKING:
+ from pydantic_ai.messages import ModelMessage
+ from pydantic_ai.settings import ModelSettings
+
+log = structlog.get_logger(logger_name="task")
+
+_MODEL_REQUEST_PARAMETERS_ADAPTER = TypeAdapter(ModelRequestParameters)
+
+# Message-level fields regenerated on every attempt.
+_VOLATILE_MESSAGE_KEYS = ("timestamp", "run_id", "conversation_id")
+
+# Settings that control transport, not response content. Excluded from the
+# fingerprint: changing them should not invalidate a cached response, and some
+# (``timeout`` can be an ``httpx.Timeout``) are not JSON-serializable, which
+# would otherwise force the whole fingerprint to ``None`` and silently disable
+# replay verification for every step.
+_TRANSPORT_ONLY_SETTINGS = frozenset({"timeout"})
+
+
+def _content_settings(model_settings: ModelSettings | None) -> dict[str, Any]
| None:
+ """Return the content-affecting settings, or ``None`` if there are none."""
+ if not model_settings:
+ return None
+ content = {k: v for k, v in model_settings.items() if k not in
_TRANSPORT_ONLY_SETTINGS}
+ return content or None
+
+
+def _strip_volatile(messages_dump: list[dict[str, Any]]) -> list[dict[str,
Any]]:
+ """
+ Drop per-attempt fields from a dumped message list.
+
+ Only the levels pydantic-ai regenerates are touched (message-level ids and
+ timestamps, part-level timestamps); user data such as tool arguments is
+ never recursed into, so an argument legitimately named ``run_id`` still
+ affects the fingerprint.
+ """
+ stripped = []
+ for message in messages_dump:
+ cleaned = {k: v for k, v in message.items() if k not in
_VOLATILE_MESSAGE_KEYS}
+ if isinstance(cleaned.get("parts"), list):
+ cleaned["parts"] = [
+ {k: v for k, v in part.items() if k != "timestamp"} if
isinstance(part, dict) else part
+ for part in cleaned["parts"]
+ ]
+ stripped.append(cleaned)
+ return stripped
+
+
+def _digest(payload: Any) -> str:
+ # No ``default=`` fallback: a non-JSON-serializable value must raise so the
+ # callers degrade to an unverifiable (None) fingerprint instead of hashing
+ # process-local reprs like ``<object at 0x...>`` that never match on retry.
+ canonical = json.dumps(payload, sort_keys=True)
+ return hashlib.sha256(canonical.encode()).hexdigest()
+
+
+def fingerprint_model_request(
+ model_identifier: str,
+ messages: list[ModelMessage],
+ model_settings: ModelSettings | None,
+ model_request_parameters: ModelRequestParameters,
+) -> str | None:
+ """
+ Fingerprint a model request: model identity, message history, settings,
and request parameters.
+
+ The full ``ModelRequestParameters`` object is hashed (tool definitions,
+ output mode and schema, native tools, ...) so any change to what is sent
+ to the model invalidates the cached response.
+
+ Returns ``None`` when the request cannot be serialized; ``None`` compares
+ equal to ``None``, so requests that cannot be fingerprinted degrade to
+ unverified positional replay rather than disabling caching.
+ """
+ try:
+ dumped = ModelMessagesTypeAdapter.dump_python(messages, mode="json")
+ params =
_MODEL_REQUEST_PARAMETERS_ADAPTER.dump_python(model_request_parameters,
mode="json")
+ return _digest(
+ {
+ "model": model_identifier,
+ "messages": _strip_volatile(dumped),
+ "settings": _content_settings(model_settings),
+ "params": params,
+ }
+ )
+ except (TypeError, ValueError):
+ # TypeError from json.dumps, ValueError covers
PydanticSerializationError
+ log.warning(
+ "Durable: could not fingerprint model request; cached responses
for this "
+ "step replay without verification"
+ )
+ return None
+
+
+def fingerprint_tool_call(name: str, tool_args: dict[str, Any], tool_call_id:
str | None) -> str | None:
+ """
+ Fingerprint a tool call: tool name, arguments, and the model-issued call
id.
+
+ ``tool_call_id`` round-trips through the model-response cache, so it is
+ stable under faithful replay but regenerated whenever a live model call
+ replaces a cached response -- chaining invalidation to downstream tool
steps.
+ """
+ try:
+ return _digest({"name": name, "args": tool_args, "tool_call_id":
tool_call_id})
+ except (TypeError, ValueError):
+ log.warning(
+ "Durable: could not fingerprint tool call; cached results for this
"
+ "step replay without verification",
+ tool=name,
+ )
+ return None
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py
b/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py
index 85643f9b34a..a32b4a1ee3d 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/durable/step_counter.py
@@ -24,7 +24,9 @@ class DurableStepCounter:
Monotonically increasing counter shared between CachingModel and
CachingToolset.
Each model call and tool call increments the counter. The step index
- is used as the cache key, ensuring deterministic replay on retry.
+ is used as the cache key; replay correctness is verified separately by
+ comparing the request fingerprint stored with each cache entry (see
+ ``airflow.providers.common.ai.durable.fingerprint``).
"""
def __init__(self) -> None:
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py
b/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py
index d4437b18d11..88de494fe91 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/durable/storage.py
@@ -99,24 +99,50 @@ class DurableStorage:
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(json.dumps(self._cache))
- def save_model_response(self, key: str, response: ModelResponse) -> None:
- """Serialize and store a ModelResponse in the cache."""
+ def save_model_response(self, key: str, response: ModelResponse, *,
fingerprint: str | None) -> None:
+ """Serialize and store a ModelResponse with the request fingerprint
that produced it."""
cache = self._load_cache()
- cache[key] = ModelMessagesTypeAdapter.dump_json([response]).decode()
+ # Store the dumped messages as native JSON-compatible objects, not a
+ # pre-encoded string: the whole cache is JSON-encoded once in
+ # ``_save_cache``, so embedding a string here would double-encode the
+ # (large) response payload.
+ cache[key] = {
+ "fingerprint": fingerprint,
+ "data": ModelMessagesTypeAdapter.dump_python([response],
mode="json"),
+ }
self._save_cache()
- def load_model_response(self, key: str) -> ModelResponse | None:
- """Load a cached ModelResponse, or return None if not cached."""
+ def load_model_response(self, key: str) -> tuple[ModelResponse | None, str
| None]:
+ """
+ Load a cached ModelResponse and its stored request fingerprint.
+
+ Returns ``(None, None)`` if not cached. Entries written before
+ fingerprints existed load with a ``None`` fingerprint.
+ """
cache = self._load_cache()
raw = cache.get(key)
if raw is None:
- return None
- messages = ModelMessagesTypeAdapter.validate_json(raw)
- return messages[0] # type: ignore[return-value]
-
- def save_tool_result(self, key: str, result: Any) -> None:
+ return None, None
+ try:
+ if isinstance(raw, dict):
+ messages =
ModelMessagesTypeAdapter.validate_python(raw["data"])
+ fingerprint = raw.get("fingerprint")
+ else:
+ # Legacy entry: the adapter JSON (a list) was stored directly
as a string.
+ messages = ModelMessagesTypeAdapter.validate_json(raw)
+ fingerprint = None
+ except (KeyError, IndexError, ValueError):
+ # A torn or malformed entry degrades to a miss (the step re-runs),
+ # never a task crash -- the cache is best-effort.
+ log.warning("Durable: ignoring malformed cached model response",
key=key)
+ return None, None
+ if not messages:
+ return None, None
+ return messages[0], fingerprint # type: ignore[return-value]
+
+ def save_tool_result(self, key: str, result: Any, *, fingerprint: str |
None) -> None:
"""
- Store a tool call result in the cache.
+ Store a tool call result with the call fingerprint that produced it.
Non-serializable results (e.g. BinaryContent from MCP tools) are
skipped with a warning -- the tool call still succeeds, but won't
@@ -124,30 +150,39 @@ class DurableStorage:
"""
cache = self._load_cache()
try:
- cache[key] = json.dumps({_SENTINEL: True, "value": result})
- except TypeError:
+ # Probe serializability before mutating the shared cache: a
+ # non-serializable result must skip only this entry, not break the
+ # whole-file ``_save_cache``. TypeError covers unsupported types;
+ # ValueError covers circular references.
+ json.dumps(result)
+ except (TypeError, ValueError):
log.warning(
"Durable: skipping cache for non-serializable tool result",
key=key,
type=type(result).__name__,
)
return
+ cache[key] = {_SENTINEL: True, "value": result, "fingerprint":
fingerprint}
self._save_cache()
- def load_tool_result(self, key: str) -> tuple[bool, Any]:
+ def load_tool_result(self, key: str) -> tuple[bool, Any, str | None]:
"""
- Load a cached tool result.
+ Load a cached tool result and its stored call fingerprint.
- Returns (found, value) tuple since the cached value itself could be
None.
+ Returns a (found, value, fingerprint) tuple since the cached value
+ itself could be None. Entries written before fingerprints existed
+ load with a ``None`` fingerprint.
"""
cache = self._load_cache()
raw = cache.get(key)
if raw is None:
- return False, None
- parsed = json.loads(raw)
- if not isinstance(parsed, dict) or _SENTINEL not in parsed:
- return False, None
- return True, parsed["value"]
+ return False, None, None
+ # Legacy entries were stored as a JSON string; new entries are native
dicts.
+ if isinstance(raw, str):
+ raw = json.loads(raw)
+ if not isinstance(raw, dict) or _SENTINEL not in raw:
+ return False, None, None
+ return True, raw["value"], raw.get("fingerprint")
def cleanup(self) -> None:
"""Delete the cache file after successful execution."""
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index 6468bfeaa40..bda06ea1f56 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -148,7 +148,11 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
or tool budget. ``None`` (default) means no enforcement.
:param durable: When ``True``, enables step-level caching of model
responses and tool results for durable execution. On retry, cached
- steps are replayed instead of re-executing. Default ``False``.
+ steps are replayed instead of re-executing. Each cached step is
+ verified against the current request before replay: if the prompt,
+ model, settings, tools, or message history changed since the failed
+ attempt, the affected steps re-run live (with a warning) instead of
+ replaying stale results. Default ``False``.
Requires ``[common.ai] durable_cache_path`` to be set.
:param code_mode: When ``True``, wraps the agent's tools in a single
``run_code`` tool powered by the Monty sandbox (pydantic-ai-harness
diff --git
a/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py
b/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py
index 117f5d4a08c..2b00fa4a408 100644
--- a/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py
+++ b/providers/common/ai/tests/unit/common/ai/durable/test_caching_model.py
@@ -20,15 +20,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from pydantic_ai.messages import ModelResponse, TextPart
+from pydantic_ai.models import ModelRequestParameters
from airflow.providers.common.ai.durable.caching_model import CachingModel
+from airflow.providers.common.ai.durable.fingerprint import
fingerprint_model_request
from airflow.providers.common.ai.durable.step_counter import DurableStepCounter
@pytest.fixture
def mock_storage():
storage = MagicMock()
- storage.load_model_response.return_value = None
+ storage.load_model_response.return_value = (None, None)
return storage
@@ -44,6 +46,8 @@ def mock_model():
model.system = "test"
model.profile = MagicMock()
model.settings = None
+ # CachingModel fingerprints the prepared request; identity keeps prepared
== raw.
+ model.prepare_request = lambda settings, params: (settings, params)
return model
@@ -59,15 +63,22 @@ def sample_response():
return ModelResponse(parts=[TextPart(content="Hello!")])
+def request_fingerprint(messages=(), settings=None, params=None):
+ """Fingerprint matching what CachingModel computes for the mock model."""
+ return fingerprint_model_request(
+ "test:test-model", list(messages), settings, params or
ModelRequestParameters()
+ )
+
+
class TestCachingModelCacheHit:
@pytest.mark.asyncio
async def test_returns_cached_response_without_calling_model(
self, mock_model, mock_storage, counter, sample_response
):
- mock_storage.load_model_response.return_value = sample_response
+ mock_storage.load_model_response.return_value = (sample_response,
request_fingerprint())
caching = CachingModel(mock_model, storage=mock_storage,
counter=counter)
- result = await caching.request([], None, MagicMock())
+ result = await caching.request([], None, ModelRequestParameters())
assert result is sample_response
mock_model.request.assert_not_called()
@@ -75,10 +86,10 @@ class TestCachingModelCacheHit:
@pytest.mark.asyncio
async def test_advances_counter_on_cache_hit(self, mock_model,
mock_storage, counter, sample_response):
- mock_storage.load_model_response.return_value = sample_response
+ mock_storage.load_model_response.return_value = (sample_response,
request_fingerprint())
caching = CachingModel(mock_model, storage=mock_storage,
counter=counter)
- await caching.request([], None, MagicMock())
+ await caching.request([], None, ModelRequestParameters())
assert counter.total_steps == 1
@@ -89,11 +100,13 @@ class TestCachingModelCacheMiss:
mock_model.request = AsyncMock(return_value=sample_response)
caching = CachingModel(mock_model, storage=mock_storage,
counter=counter)
- result = await caching.request([], None, MagicMock())
+ result = await caching.request([], None, ModelRequestParameters())
assert result is sample_response
mock_model.request.assert_called_once()
-
mock_storage.save_model_response.assert_called_once_with("model_step_0",
sample_response)
+ mock_storage.save_model_response.assert_called_once_with(
+ "model_step_0", sample_response, fingerprint=request_fingerprint()
+ )
@pytest.mark.asyncio
async def test_sequential_calls_use_incrementing_keys(self, mock_model,
mock_storage, counter):
@@ -102,8 +115,71 @@ class TestCachingModelCacheMiss:
mock_model.request = AsyncMock(side_effect=[response_1, response_2])
caching = CachingModel(mock_model, storage=mock_storage,
counter=counter)
- await caching.request([], None, MagicMock())
- await caching.request([], None, MagicMock())
+ await caching.request([], None, ModelRequestParameters())
+ await caching.request([], None, ModelRequestParameters())
keys = [call[0][0] for call in
mock_storage.save_model_response.call_args_list]
assert keys == ["model_step_0", "model_step_1"]
+
+
+class TestCachingModelReplayVerification:
+ @pytest.mark.asyncio
+ async def test_fingerprint_mismatch_treated_as_miss(
+ self, mock_model, mock_storage, counter, sample_response
+ ):
+ """A cached entry recorded for a different request must not be
replayed."""
+ stale = ModelResponse(parts=[TextPart(content="stale")])
+ mock_storage.load_model_response.return_value = (stale,
"fp_of_old_conversation")
+ mock_model.request = AsyncMock(return_value=sample_response)
+ caching = CachingModel(mock_model, storage=mock_storage,
counter=counter)
+
+ result = await caching.request([], None, ModelRequestParameters())
+
+ assert result is sample_response
+ mock_model.request.assert_called_once()
+ assert counter.replayed_model == 0
+ mock_storage.save_model_response.assert_called_once_with(
+ "model_step_0", sample_response, fingerprint=request_fingerprint()
+ )
+
+ @pytest.mark.asyncio
+ async def test_legacy_entry_without_fingerprint_treated_as_miss(
+ self, mock_model, mock_storage, counter, sample_response
+ ):
+ """Pre-fingerprint cache entries cannot be verified, so they re-run."""
+ stale = ModelResponse(parts=[TextPart(content="stale")])
+ mock_storage.load_model_response.return_value = (stale, None)
+ mock_model.request = AsyncMock(return_value=sample_response)
+ caching = CachingModel(mock_model, storage=mock_storage,
counter=counter)
+
+ result = await caching.request([], None, ModelRequestParameters())
+
+ assert result is sample_response
+ mock_model.request.assert_called_once()
+
+ @pytest.mark.asyncio
+ async def test_fingerprint_uses_prepared_request_not_raw_arguments(
+ self, mock_storage, counter, sample_response
+ ):
+ """Concrete models merge model-level settings in ``prepare_request``
before the
+ provider sees the request. The fingerprint must reflect the prepared
settings,
+ so a model-level change (e.g. a different temperature on the
connection) is not
+ invisible behind identical raw ``request()`` arguments."""
+ model = MagicMock()
+ model.model_name = "test-model"
+ model.system = "test"
+ model.profile = MagicMock()
+ model.settings = None
+ model.request = AsyncMock(return_value=sample_response)
+ # Simulate prepare_request merging a model-level temperature into
settings.
+ model.prepare_request = lambda settings, params: ({"temperature":
0.9}, params)
+ caching = CachingModel(model, storage=mock_storage, counter=counter)
+
+ await caching.request([], None, ModelRequestParameters())
+
+ stored_fingerprint =
mock_storage.save_model_response.call_args.kwargs["fingerprint"]
+ # Reflects the prepared settings, not the raw ``None`` the agent
passed in.
+ assert stored_fingerprint == fingerprint_model_request(
+ "test:test-model", [], {"temperature": 0.9},
ModelRequestParameters()
+ )
+ assert stored_fingerprint != request_fingerprint()
diff --git
a/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py
b/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py
index d2d999d9218..d7104928a14 100644
--- a/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py
+++ b/providers/common/ai/tests/unit/common/ai/durable/test_caching_toolset.py
@@ -16,21 +16,24 @@
 # under the License.
 from __future__ import annotations
 
+from types import SimpleNamespace
 from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from pydantic_ai.messages import ModelResponse, TextPart
+from pydantic_ai.models import ModelRequestParameters
 
 from airflow.providers.common.ai.durable.caching_model import CachingModel
 from airflow.providers.common.ai.durable.caching_toolset import CachingToolset
+from airflow.providers.common.ai.durable.fingerprint import fingerprint_tool_call
 from airflow.providers.common.ai.durable.step_counter import DurableStepCounter
 
 
 @pytest.fixture
 def mock_storage():
     storage = MagicMock()
-    storage.load_tool_result.return_value = (False, None)
-    storage.load_model_response.return_value = None
+    storage.load_tool_result.return_value = (False, None, None)
+    storage.load_model_response.return_value = (None, None)
     return storage
 
 
@@ -49,13 +52,18 @@ def mock_toolset():
     return toolset
 
 
+def ctx_for(tool_call_id: str | None = "call_1") -> SimpleNamespace:
+    return SimpleNamespace(tool_call_id=tool_call_id)
+
+
 class TestCachingToolsetCacheHit:
     @pytest.mark.asyncio
     async def test_returns_cached_result_without_calling_tool(self, mock_toolset, mock_storage, counter):
-        mock_storage.load_tool_result.return_value = (True, "cached result")
+        fingerprint = fingerprint_tool_call("search", {"q": "foo"}, "call_1")
+        mock_storage.load_tool_result.return_value = (True, "cached result", fingerprint)
         caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter)
 
-        result = await caching.call_tool("search", {"q": "foo"}, MagicMock(), MagicMock())
+        result = await caching.call_tool("search", {"q": "foo"}, ctx_for("call_1"), MagicMock())
 
         assert result == "cached result"
         mock_toolset.call_tool.assert_not_called()
@@ -63,10 +71,11 @@ class TestCachingToolsetCacheHit:
 
     @pytest.mark.asyncio
     async def test_advances_counter_on_cache_hit(self, mock_toolset, mock_storage, counter):
-        mock_storage.load_tool_result.return_value = (True, "cached")
+        fingerprint = fingerprint_tool_call("search", {}, "call_1")
+        mock_storage.load_tool_result.return_value = (True, "cached", fingerprint)
         caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter)
 
-        await caching.call_tool("search", {}, MagicMock(), MagicMock())
+        await caching.call_tool("search", {}, ctx_for("call_1"), MagicMock())
 
         assert counter.total_steps == 1
 
@@ -76,24 +85,66 @@ class TestCachingToolsetCacheMiss:
     async def test_calls_tool_and_caches_on_miss(self, mock_toolset, mock_storage, counter):
         caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter)
 
-        result = await caching.call_tool("search", {"q": "foo"}, MagicMock(), MagicMock())
+        result = await caching.call_tool("search", {"q": "foo"}, ctx_for("call_1"), MagicMock())
 
         assert result == "fresh result"
         mock_toolset.call_tool.assert_called_once()
-        mock_storage.save_tool_result.assert_called_once_with("tool_step_0", "fresh result")
+        mock_storage.save_tool_result.assert_called_once_with(
+            "tool_step_0", "fresh result", fingerprint=fingerprint_tool_call("search", {"q": "foo"}, "call_1")
+        )
 
     @pytest.mark.asyncio
     async def test_sequential_calls_use_incrementing_keys(self, mock_toolset, mock_storage, counter):
         mock_toolset.call_tool = AsyncMock(side_effect=["result_a", "result_b"])
         caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter)
 
-        await caching.call_tool("tool_a", {}, MagicMock(), MagicMock())
-        await caching.call_tool("tool_b", {}, MagicMock(), MagicMock())
+        await caching.call_tool("tool_a", {}, ctx_for(), MagicMock())
+        await caching.call_tool("tool_b", {}, ctx_for(), MagicMock())
 
         keys = [call[0][0] for call in mock_storage.save_tool_result.call_args_list]
         assert keys == ["tool_step_0", "tool_step_1"]
 
 
+class TestCachingToolsetReplayVerification:
+    @pytest.mark.asyncio
+    async def test_different_tool_call_treated_as_miss(self, mock_toolset, mock_storage, counter):
+        """A cached result recorded for a different tool call must not be replayed."""
+        stale_fingerprint = fingerprint_tool_call("lookup_order", {"id": "A1"}, "old_call")
+        mock_storage.load_tool_result.return_value = (True, "stale result", stale_fingerprint)
+        caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter)
+
+        result = await caching.call_tool("charge_card", {"amount": 5}, ctx_for("new_call"), MagicMock())
+
+        assert result == "fresh result"
+        mock_toolset.call_tool.assert_called_once()
+        assert counter.replayed_tool == 0
+
+    @pytest.mark.asyncio
+    async def test_changed_tool_call_id_treated_as_miss(self, mock_toolset, mock_storage, counter):
+        """Same name/args but a new model-issued call id means the conversation diverged."""
+        stale_fingerprint = fingerprint_tool_call("search", {"q": "foo"}, "old_call")
+        mock_storage.load_tool_result.return_value = (True, "stale result", stale_fingerprint)
+        caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter)
+
+        result = await caching.call_tool("search", {"q": "foo"}, ctx_for("new_call"), MagicMock())
+
+        assert result == "fresh result"
+        mock_toolset.call_tool.assert_called_once()
+
+    @pytest.mark.asyncio
+    async def test_legacy_entry_without_fingerprint_treated_as_miss(
+        self, mock_toolset, mock_storage, counter
+    ):
+        """Pre-fingerprint cache entries cannot be verified, so the tool re-runs."""
+        mock_storage.load_tool_result.return_value = (True, "stale result", None)
+        caching = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter)
+
+        result = await caching.call_tool("search", {"q": "foo"}, ctx_for("call_1"), MagicMock())
+
+        assert result == "fresh result"
+        mock_toolset.call_tool.assert_called_once()
+
+
 class TestSharedCounter:
     @pytest.mark.asyncio
     async def test_model_and_toolset_share_counter(self, mock_toolset, mock_storage):
@@ -105,6 +156,7 @@ class TestSharedCounter:
         mock_model.system = "test"
         mock_model.profile = MagicMock()
         mock_model.settings = None
+        mock_model.prepare_request = lambda settings, params: (settings, params)
         response = ModelResponse(parts=[TextPart(content="response")])
         mock_model.request = AsyncMock(return_value=response)
 
@@ -114,9 +166,9 @@ class TestSharedCounter:
         caching_toolset = CachingToolset(wrapped=mock_toolset, storage=mock_storage, counter=counter)
 
         # Simulate: model call -> tool call -> model call
-        await caching_model.request([], None, MagicMock())
-        await caching_toolset.call_tool("search", {}, MagicMock(), MagicMock())
-        await caching_model.request([], None, MagicMock())
+        await caching_model.request([], None, ModelRequestParameters())
+        await caching_toolset.call_tool("search", {}, ctx_for(), MagicMock())
+        await caching_model.request([], None, ModelRequestParameters())
 
         model_keys = [call[0][0] for call in mock_storage.save_model_response.call_args_list]
         tool_keys = [call[0][0] for call in mock_storage.save_tool_result.call_args_list]
diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_fingerprint.py b/providers/common/ai/tests/unit/common/ai/durable/test_fingerprint.py
new file mode 100644
index 00000000000..94e517d1cc3
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/durable/test_fingerprint.py
@@ -0,0 +1,212 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import datetime
+
+import httpx
+from pydantic_ai.messages import (
+ ModelRequest,
+ ModelResponse,
+ SystemPromptPart,
+ ToolCallPart,
+ UserPromptPart,
+)
+from pydantic_ai.models import ModelRequestParameters
+from pydantic_ai.tools import ToolDefinition
+
+from airflow.providers.common.ai.durable.fingerprint import (
+ fingerprint_model_request,
+ fingerprint_tool_call,
+)
+
+
+def make_messages(system: str = "You are a bot.", user: str = "hello", **part_kwargs):
+    return [
+        ModelRequest(
+            parts=[
+                SystemPromptPart(content=system, **part_kwargs),
+                UserPromptPart(content=user, **part_kwargs),
+            ]
+        )
+    ]
+
+
+class TestModelRequestFingerprint:
+    def test_stable_across_part_timestamps(self):
+        """Part timestamps regenerate on every attempt and must not affect the fingerprint."""
+        t1 = datetime.datetime(2026, 1, 1, tzinfo=datetime.timezone.utc)
+        t2 = datetime.datetime(2026, 1, 2, tzinfo=datetime.timezone.utc)
+        fp1 = fingerprint_model_request("m", make_messages(timestamp=t1), None, ModelRequestParameters())
+        fp2 = fingerprint_model_request("m", make_messages(timestamp=t2), None, ModelRequestParameters())
+
+        assert fp1 == fp2
+
+    def test_stable_across_separate_message_constructions(self):
+        """run_id/conversation_id and other per-run fields must not affect the fingerprint."""
+        fp1 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters())
+        fp2 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters())
+
+        assert fp1 == fp2
+
+    def test_changes_with_system_prompt(self):
+        fp1 = fingerprint_model_request("m", make_messages(system="a"), None, ModelRequestParameters())
+        fp2 = fingerprint_model_request("m", make_messages(system="b"), None, ModelRequestParameters())
+
+        assert fp1 != fp2
+
+    def test_changes_with_user_prompt(self):
+        fp1 = fingerprint_model_request("m", make_messages(user="a"), None, ModelRequestParameters())
+        fp2 = fingerprint_model_request("m", make_messages(user="b"), None, ModelRequestParameters())
+
+        assert fp1 != fp2
+
+    def test_changes_with_model_identifier(self):
+        fp1 = fingerprint_model_request("openai:gpt-5", make_messages(), None, ModelRequestParameters())
+        fp2 = fingerprint_model_request("openai:gpt-5-mini", make_messages(), None, ModelRequestParameters())
+
+        assert fp1 != fp2
+
+    def test_changes_with_model_settings(self):
+        fp1 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters())
+        fp2 = fingerprint_model_request("m", make_messages(), {"temperature": 0.5}, ModelRequestParameters())
+
+        assert fp1 != fp2
+
+    def test_changes_with_toolset(self):
+        tool = ToolDefinition(name="search", parameters_json_schema={"type": "object"})
+        fp1 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters())
+        fp2 = fingerprint_model_request(
+            "m", make_messages(), None, ModelRequestParameters(function_tools=[tool])
+        )
+
+        assert fp1 != fp2
+
+    def test_changes_with_output_mode(self):
+        """The full request parameters are hashed, not just the tool list."""
+        fp1 = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters())
+        fp2 = fingerprint_model_request(
+            "m", make_messages(), None, ModelRequestParameters(output_mode="native")
+        )
+
+        assert fp1 != fp2
+
+    def test_changes_with_tool_definition_fields(self):
+        """Changes inside a tool definition (e.g. strict mode) affect the fingerprint."""
+        strict = ToolDefinition(name="t", parameters_json_schema={"type": "object"}, strict=True)
+        lax = ToolDefinition(name="t", parameters_json_schema={"type": "object"}, strict=False)
+        fp1 = fingerprint_model_request(
+            "m", make_messages(), None, ModelRequestParameters(function_tools=[strict])
+        )
+        fp2 = fingerprint_model_request(
+            "m", make_messages(), None, ModelRequestParameters(function_tools=[lax])
+        )
+
+        assert fp1 != fp2
+
+    def test_volatile_keys_inside_user_data_are_not_stripped(self):
+        """Only pydantic-ai's own message/part fields are volatile; a tool argument
+        legitimately named run_id must still affect the fingerprint."""
+
+        def messages_with_args(args):
+            return [
+                ModelRequest(parts=[UserPromptPart(content="q")]),
+                ModelResponse(parts=[ToolCallPart(tool_name="t", args=args, tool_call_id="id1")]),
+            ]
+
+        fp1 = fingerprint_model_request(
+            "m", messages_with_args({"run_id": "a"}), None, ModelRequestParameters()
+        )
+        fp2 = fingerprint_model_request(
+            "m", messages_with_args({"run_id": "b"}), None, ModelRequestParameters()
+        )
+
+        assert fp1 != fp2
+
+    def test_unserializable_request_returns_none(self):
+        fp = fingerprint_model_request("m", [object()], None, ModelRequestParameters())  # type: ignore[list-item]
+
+        assert fp is None
+
+    def test_unserializable_settings_returns_none(self):
+        """Non-JSON settings values degrade to unverified replay instead of hashing
+        process-local reprs that would never match on retry."""
+        fp = fingerprint_model_request(
+            "m", make_messages(), {"extra_body": object()}, ModelRequestParameters()
+        )  # type: ignore[typeddict-item]
+
+        assert fp is None
+
+    def test_httpx_timeout_does_not_disable_fingerprint(self):
+        """``timeout`` may be an ``httpx.Timeout`` (a supported, non-JSON shape).
+        It must not force the fingerprint to None and silently disable verification."""
+        fp = fingerprint_model_request(
+            "m", make_messages(), {"timeout": httpx.Timeout(30.0)}, ModelRequestParameters()
+        )
+
+        assert fp is not None
+
+    def test_timeout_excluded_from_fingerprint(self):
+        """timeout is transport-only -- changing it (or its type) must not invalidate
+        the cached response, so it is the same fingerprint as no timeout at all."""
+        no_timeout = fingerprint_model_request("m", make_messages(), None, ModelRequestParameters())
+        float_timeout = fingerprint_model_request(
+            "m", make_messages(), {"timeout": 30.0}, ModelRequestParameters()
+        )
+        httpx_timeout = fingerprint_model_request(
+            "m", make_messages(), {"timeout": httpx.Timeout(5.0)}, ModelRequestParameters()
+        )
+
+        assert no_timeout == float_timeout == httpx_timeout
+
+    def test_content_settings_still_count_when_timeout_present(self):
+        """Stripping timeout must not drop content settings sharing the dict."""
+        low = fingerprint_model_request(
+            "m",
+            make_messages(),
+            {"temperature": 0.2, "timeout": httpx.Timeout(1.0)},
+            ModelRequestParameters(),
+        )
+        high = fingerprint_model_request(
+            "m",
+            make_messages(),
+            {"temperature": 0.9, "timeout": httpx.Timeout(1.0)},
+            ModelRequestParameters(),
+        )
+
+        assert low is not None
+        assert high is not None
+        assert low != high
+
+
+class TestToolCallFingerprint:
+    def test_stable_for_identical_call(self):
+        assert fingerprint_tool_call("t", {"a": 1}, "id1") == fingerprint_tool_call("t", {"a": 1}, "id1")
+
+    def test_changes_with_name(self):
+        assert fingerprint_tool_call("a", {}, "id1") != fingerprint_tool_call("b", {}, "id1")
+
+    def test_changes_with_args(self):
+        assert fingerprint_tool_call("t", {"a": 1}, "id1") != fingerprint_tool_call("t", {"a": 2}, "id1")
+
+    def test_changes_with_tool_call_id(self):
+        assert fingerprint_tool_call("t", {}, "id1") != fingerprint_tool_call("t", {}, "id2")
+
+    def test_arg_order_does_not_matter(self):
+        assert fingerprint_tool_call("t", {"a": 1, "b": 2}, "id1") == fingerprint_tool_call(
+            "t", {"b": 2, "a": 1}, "id1"
+        )
diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_replay_verification.py b/providers/common/ai/tests/unit/common/ai/durable/test_replay_verification.py
new file mode 100644
index 00000000000..ad9c244b0c6
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/durable/test_replay_verification.py
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+End-to-end replay verification through a real pydantic-ai agent loop.
+
+Simulates the retry scenario durable execution exists for: attempt 1 fails
+partway with steps cached, attempt 2 starts a fresh counter against the same
+cache file. Replay must happen when the agent is unchanged and must NOT happen
+when the agent changed between attempts (the positional-keying staleness bug).
+"""
+
+from __future__ import annotations
+
+from unittest.mock import patch
+
+import pytest
+from pydantic_ai import Agent
+from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart
+from pydantic_ai.models.function import AgentInfo, FunctionModel
+from pydantic_ai.toolsets import FunctionToolset
+
+from airflow.providers.common.ai.durable.caching_model import CachingModel
+from airflow.providers.common.ai.durable.caching_toolset import CachingToolset
+from airflow.providers.common.ai.durable.step_counter import DurableStepCounter
+from airflow.providers.common.ai.durable.storage import DurableStorage
+from airflow.sdk import ObjectStoragePath
+
+
[email protected]
+def storage(tmp_path):
+ with patch("airflow.providers.common.ai.durable.storage._get_base_path")
as mock_base:
+ mock_base.return_value =
ObjectStoragePath(f"file://{tmp_path.as_posix()}")
+ yield DurableStorage(dag_id="d", task_id="t", run_id="r")
+
+
+class AgentHarness:
+    """A scripted two-step agent: one tool call, then a final answer."""
+
+    def __init__(self, storage: DurableStorage, *, system_prompt: str, rate: float, fail: bool):
+        self.live_model_calls = 0
+        self.live_tool_calls = 0
+        self.counter = DurableStepCounter()
+
+        def model_fn(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+            self.live_model_calls += 1
+            if not any(isinstance(m, ModelResponse) for m in messages):
+                return ModelResponse(parts=[ToolCallPart(tool_name="get_fx_rate", args={})])
+            if fail:
+                raise RuntimeError("simulated transient failure")
+            returned = next(p.content for m in messages for p in m.parts if p.part_kind == "tool-return")
+            return ModelResponse(parts=[TextPart(content=f"rate={returned}")])
+
+        def get_fx_rate() -> float:
+            """Return the USD to EUR exchange rate."""
+            self.live_tool_calls += 1
+            return rate
+
+        # Constructor form: the @toolset.tool decorator is typed for
+        # RunContext-taking functions only.
+        toolset = FunctionToolset(tools=[get_fx_rate])
+
+        self.agent = Agent(
+            model=CachingModel(FunctionModel(model_fn), storage=storage, counter=self.counter),
+            system_prompt=system_prompt,
+            toolsets=[CachingToolset(wrapped=toolset, storage=storage, counter=self.counter)],
+        )
+
+    async def run(self) -> str:
+        result = await self.agent.run("Convert 100 USD to EUR")
+        return result.output
+
+
+async def failed_first_attempt(storage: DurableStorage, system_prompt: str = "currency bot") -> None:
+    """Attempt 1: tool result 0.42 gets cached, then the final model call fails."""
+    harness = AgentHarness(storage, system_prompt=system_prompt, rate=0.42, fail=True)
+    with pytest.raises(RuntimeError, match="simulated transient failure"):
+        await harness.run()
+    storage._cache = None  # retry runs in a new process: in-memory cache is cold
+
+
+class TestUnchangedRetryReplays:
+    @pytest.mark.asyncio
+    async def test_completed_steps_replay_only_failed_step_reruns(self, storage):
+        await failed_first_attempt(storage)
+
+        retry = AgentHarness(storage, system_prompt="currency bot", rate=0.42, fail=False)
+        output = await retry.run()
+
+        assert output == "rate=0.42"
+        assert retry.counter.replayed_model == 1
+        assert retry.counter.replayed_tool == 1
+        assert retry.live_tool_calls == 0
+        assert retry.live_model_calls == 1  # only the step that failed on attempt 1
+
+
+class TestChangedAgentDoesNotReplayStaleSteps:
+    @pytest.mark.asyncio
+    async def test_changed_system_prompt_invalidates_replay(self, storage):
+        """Regression test for positional-keying staleness: a prompt tweak between
+        attempts must re-run the conversation, not replay the old one."""
+        await failed_first_attempt(storage)
+
+        retry = AgentHarness(storage, system_prompt="careful currency bot", rate=0.99, fail=False)
+        output = await retry.run()
+
+        # The fixed tool ran in the new conversation; nothing stale was replayed.
+        assert output == "rate=0.99"
+        assert retry.counter.replayed_model == 0
+        assert retry.counter.replayed_tool == 0
+        assert retry.live_tool_calls == 1
+
+    @pytest.mark.asyncio
+    async def test_divergence_chains_to_downstream_tool_steps(self, storage):
+        """A live model call mints fresh tool_call_ids, so cached tool results
+        recorded under the old conversation cannot be cross-wired into the new one."""
+        await failed_first_attempt(storage)
+
+        retry = AgentHarness(storage, system_prompt="careful currency bot", rate=0.42, fail=False)
+        await retry.run()
+
+        # Same tool name and args as attempt 1, but the conversation diverged
+        # at step 0 -- the tool must run live, not replay.
+        assert retry.live_tool_calls == 1
+        assert retry.counter.replayed_tool == 0
diff --git a/providers/common/ai/tests/unit/common/ai/durable/test_storage.py b/providers/common/ai/tests/unit/common/ai/durable/test_storage.py
index 507fd3126ac..03f85be30e1 100644
--- a/providers/common/ai/tests/unit/common/ai/durable/test_storage.py
+++ b/providers/common/ai/tests/unit/common/ai/durable/test_storage.py
@@ -16,13 +16,16 @@
# under the License.
from __future__ import annotations
+import json
from unittest.mock import patch
import pytest
from pydantic_ai.messages import (
+ ModelMessagesTypeAdapter,
ModelResponse,
TextPart,
)
+from pydantic_ai.usage import RequestUsage
from airflow.providers.common.ai.durable.storage import DurableStorage
from airflow.sdk import ObjectStoragePath
@@ -61,47 +64,147 @@ class TestDurableStorageInit:
 
 class TestSaveLoadModelResponse:
     def test_save_and_load_roundtrips(self, storage, sample_response):
-        storage.save_model_response("model_step_0", sample_response)
+        storage.save_model_response("model_step_0", sample_response, fingerprint="fp_abc")
 
         # Reset in-memory cache to force read from file
         storage._cache = None
-        loaded = storage.load_model_response("model_step_0")
+        loaded, fingerprint = storage.load_model_response("model_step_0")
 
         assert loaded is not None
         assert loaded.parts[0].content == "Hello!"
+        assert fingerprint == "fp_abc"
 
     def test_load_returns_none_when_no_cache(self, storage):
-        assert storage.load_model_response("model_step_0") is None
+        assert storage.load_model_response("model_step_0") == (None, None)
+
+    def test_metadata_carrying_response_roundtrips_byte_identical(self, storage):
+        """Multi-step replay relies on cached responses round-tripping byte-identically:
+        a later step's fingerprint includes earlier responses in history, metadata
+        (usage, provider_response_id, finish_reason) and all. If a store/load cycle
+        altered any of it, every multi-step replay would mismatch and re-run. Pin it."""
+        resp = ModelResponse(
+            parts=[TextPart(content="answer")],
+            usage=RequestUsage(input_tokens=11, output_tokens=22),
+            model_name="gpt-x",
+            provider_response_id="resp_xyz",
+            finish_reason="stop",
+        )
+        before = ModelMessagesTypeAdapter.dump_python([resp], mode="json")
+
+        storage.save_model_response("model_step_0", resp, fingerprint="fp")
+        storage._cache = None
+        loaded, _ = storage.load_model_response("model_step_0")
+
+        after = ModelMessagesTypeAdapter.dump_python([loaded], mode="json")
+        assert after == before
+
+    def test_stored_entry_is_single_encoded(self, storage, sample_response):
+        """The response payload is stored as native JSON objects, not a nested
+        JSON string -- the whole cache is encoded exactly once by ``_save_cache``."""
+        storage.save_model_response("model_step_0", sample_response, fingerprint="fp")
+
+        on_disk = json.loads(storage._get_path().read_text())
+        entry = on_disk["model_step_0"]
+
+        assert isinstance(entry, dict)  # not a re-encoded JSON string
+        assert isinstance(entry["data"], list)  # not a nested JSON string
+        assert entry["fingerprint"] == "fp"
+
+    def test_legacy_entry_without_fingerprint_loads(self, storage, sample_response):
+        """Entries written before fingerprinting (raw adapter JSON) load with a None fingerprint."""
+        cache = storage._load_cache()
+        cache["model_step_0"] = ModelMessagesTypeAdapter.dump_json([sample_response]).decode()
+        storage._save_cache()
+        storage._cache = None
+
+        loaded, fingerprint = storage.load_model_response("model_step_0")
+
+        assert loaded is not None
+        assert loaded.parts[0].content == "Hello!"
+        assert fingerprint is None
 
 
 class TestSaveLoadToolResult:
     def test_save_and_load_roundtrips(self, storage):
-        storage.save_tool_result("tool_step_0", {"rows": [1, 2, 3]})
+        storage.save_tool_result("tool_step_0", {"rows": [1, 2, 3]}, fingerprint="fp")
 
         storage._cache = None
-        found, value = storage.load_tool_result("tool_step_0")
+        found, value, fingerprint = storage.load_tool_result("tool_step_0")
 
         assert found is True
         assert value == {"rows": [1, 2, 3]}
 
+    def test_fingerprint_roundtrips(self, storage):
+        storage.save_tool_result("tool_step_0", "result", fingerprint="fp_tool")
+
+        storage._cache = None
+        found, value, fingerprint = storage.load_tool_result("tool_step_0")
+
+        assert found is True
+        assert fingerprint == "fp_tool"
+
+    def test_legacy_entry_without_fingerprint_loads(self, storage):
+        """Entries written before fingerprinting load with a None fingerprint."""
+        cache = storage._load_cache()
+        cache["tool_step_0"] = json.dumps({"__durable_cached__": True, "value": "old"})
+        storage._save_cache()
+        storage._cache = None
+
+        found, value, fingerprint = storage.load_tool_result("tool_step_0")
+
+        assert found is True
+        assert value == "old"
+        assert fingerprint is None
+
     def test_load_returns_false_when_no_cache(self, storage):
-        found, value = storage.load_tool_result("tool_step_0")
+        found, value, fingerprint = storage.load_tool_result("tool_step_0")
         assert found is False
         assert value is None
+        assert fingerprint is None
 
     def test_none_result_roundtrips(self, storage):
-        storage.save_tool_result("tool_step_0", None)
+        storage.save_tool_result("tool_step_0", None, fingerprint="fp")
 
         storage._cache = None
-        found, value = storage.load_tool_result("tool_step_0")
+        found, value, fingerprint = storage.load_tool_result("tool_step_0")
 
         assert found is True
         assert value is None
 
+    def test_circular_reference_result_is_skipped_not_raised(self, storage):
+        """A circular reference raises ValueError in json.dumps; it must skip the
+        entry with a warning, not crash the (otherwise successful) tool step."""
+        circular: dict = {}
+        circular["self"] = circular
+
+        storage.save_tool_result("tool_step_0", circular, fingerprint="fp")  # must not raise
+
+        found, _, _ = storage.load_tool_result("tool_step_0")
+        assert found is False
+
+
+class TestMalformedEntries:
+    def test_empty_data_list_degrades_to_miss(self, storage):
+        """A torn entry whose data list is empty loads as a miss, not an IndexError."""
+        cache = storage._load_cache()
+        cache["model_step_0"] = {"fingerprint": "fp", "data": []}
+        storage._save_cache()
+        storage._cache = None
+
+        assert storage.load_model_response("model_step_0") == (None, None)
+
+    def test_entry_missing_data_key_degrades_to_miss(self, storage):
+        cache = storage._load_cache()
+        cache["model_step_0"] = {"fingerprint": "fp"}
+        storage._save_cache()
+        storage._cache = None
+
+        assert storage.load_model_response("model_step_0") == (None, None)
+
 
 class TestCleanup:
     def test_cleanup_deletes_file(self, storage, sample_response):
-        storage.save_model_response("model_step_0", sample_response)
+        storage.save_model_response("model_step_0", sample_response, fingerprint="fp")
 
         path = storage._get_path()
         assert path.exists()
@@ -114,9 +217,9 @@ class TestCleanup:
 
 class TestInMemoryCaching:
     def test_multiple_saves_write_single_file(self, storage, sample_response):
-        storage.save_model_response("model_step_0", sample_response)
-        storage.save_tool_result("tool_step_1", "result")
-        storage.save_model_response("model_step_2", sample_response)
+        storage.save_model_response("model_step_0", sample_response, fingerprint="fp")
+        storage.save_tool_result("tool_step_1", "result", fingerprint="fp")
+        storage.save_model_response("model_step_2", sample_response, fingerprint="fp")
 
         assert "model_step_0" in storage._cache
         assert "tool_step_1" in storage._cache
@@ -124,16 +227,16 @@ class TestInMemoryCaching:
def test_cache_survives_reload(self, storage, sample_response):
"""Simulate retry: save cache, reset in-memory, reload from file."""
- storage.save_model_response("model_step_0", sample_response)
- storage.save_tool_result("tool_step_1", "tool result")
+ storage.save_model_response("model_step_0", sample_response,
fingerprint="fp")
+ storage.save_tool_result("tool_step_1", "tool result",
fingerprint="fp")
# Simulate new DurableStorage instance (as on retry)
storage._cache = None
- loaded_response = storage.load_model_response("model_step_0")
+ loaded_response, _ = storage.load_model_response("model_step_0")
assert loaded_response is not None
assert loaded_response.parts[0].content == "Hello!"
- found, value = storage.load_tool_result("tool_step_1")
+ found, value, _ = storage.load_tool_result("tool_step_1")
assert found is True
assert value == "tool result"