This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 531bfab9dee Add `message_history` to `AgentOperator` for multi-turn agent sessions (#68648)
531bfab9dee is described below

commit 531bfab9dee3ccdd0cea1e679f6de23b1b88b3b6
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri Jun 19 00:57:58 2026 +0100

    Add `message_history` to `AgentOperator` for multi-turn agent sessions (#68648)
    
    AgentOperator and @task.agent ran a fresh single-turn conversation every 
time. Add an opt-in message_history parameter that seeds the run with prior 
turns and pushes the post-run transcript to XCom (key 'message_history') so the 
next run can resume. Default None keeps single-turn behavior unchanged. Storing 
the transcript under a session key stays the DAG's responsibility.
---
 providers/common/ai/docs/operators/agent.rst       |  48 +++++++
 .../common/ai/example_dags/example_agent.py        |  56 +++++++-
 .../airflow/providers/common/ai/operators/agent.py |  69 +++++++++-
 .../tests/unit/common/ai/operators/test_agent.py   | 149 ++++++++++++++++++++-
 4 files changed, 318 insertions(+), 4 deletions(-)

diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst
index a79e5110048..b3805aa34b7 100644
--- a/providers/common/ai/docs/operators/agent.rst
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -156,6 +156,49 @@ tasks can consume it.
     :end-before: [END howto_agent_chain]
 
 
+Multi-turn Sessions
+-------------------
+
+By default each agent run is a cold, single-turn conversation. To carry a
+conversation across runs -- a chat or iterative agent where "and the third one?"
+must resolve against an earlier answer -- pass ``message_history``.
+
+When ``message_history`` is set, the operator seeds the run with those prior
+turns and, after the run, pushes the full updated transcript
+(``result.all_messages()``) to XCom under the key ``message_history``. The next
+run reads it back to resume the conversation. ``None`` (the default) keeps the
+single-turn behavior unchanged.
+
+The operator does **not** decide *where* a session is stored -- that keying is
+deployment-specific. The pattern is three tasks: load the prior transcript for
+the session, run the agent, store the updated transcript. The example keys a
+JSON file in object storage by ``session_id`` (use ``s3://`` / ``gs://`` in a
+deployment); the first run starts from an empty ``"[]"``.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+    :language: python
+    :start-after: [START howto_agent_session]
+    :end-before: [END howto_agent_session]
+
+``message_history`` accepts a list of pydantic-ai ``ModelMessage`` objects or
+their JSON form (``str`` / ``bytes``), so the value emitted to XCom feeds
+straight back in on the next run. If you pull it via a template instead of a
+load task like the one above, pass ``default='[]'`` so the first run -- which
+has no XCom yet -- starts a fresh session instead of trying to parse the string
+``"None"``. Also pass ``include_prior_dates=True``: by default ``xcom_pull``
+only searches the current DAG run, so it would never see the previous run's
+transcript.
+
+The transcript is **cumulative**: each turn appends to it, so it grows for the
+life of the session. For long sessions, configure an object-storage XCom backend
+or trim older turns before the next run rather than feeding the whole history
+back unbounded.
+
+.. note::
+
+    ``message_history`` cannot be combined with ``enable_hitl_review`` -- the
+    operator raises ``ValueError`` at construction. The post-review
+    (human-approved) transcript is not recoverable today, so emitting the
+    pre-review transcript would silently drop the reviewed turns.
+
+
 Durable Execution
 -----------------
 
@@ -406,6 +449,11 @@ Parameters
 - ``code_mode``: When ``True``, wraps the agent's tools in a single 
``run_code``
   tool that the model drives by writing Python, executed in the Monty sandbox.
   Requires the ``code-mode`` extra. Default ``False``. See :ref:`code-mode`.
+- ``message_history``: Prior conversation to seed a multi-turn session, as a list
+  of pydantic-ai ``ModelMessage`` objects or their JSON form (``str`` / ``bytes``).
+  When set, the post-run transcript is pushed to XCom under the key
+  ``message_history`` for the next run to resume. Default ``None`` (single-turn).
+  See `Multi-turn Sessions`_.
 
 
 Logging
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
index 6fced224c92..787b0d6dce2 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
@@ -24,7 +24,7 @@ from pydantic import BaseModel
 
 from airflow.providers.common.ai.operators.agent import AgentOperator
 from airflow.providers.common.ai.toolsets.hook import HookToolset
-from airflow.providers.common.compat.sdk import dag, task
+from airflow.providers.common.compat.sdk import ObjectStoragePath, dag, task
 
 try:
     from airflow.providers.common.ai.toolsets.sql import SQLToolset
@@ -247,3 +247,57 @@ def example_agent_operator_code_mode():
 # [END howto_operator_agent_code_mode]
 
 example_agent_operator_code_mode()
+
+
+# ---------------------------------------------------------------------------
+# 8. Multi-turn session — resume a conversation across DAG runs
+# ---------------------------------------------------------------------------
+
+
+# [START howto_agent_session]
+@dag(tags=["example"], params={"session_id": "demo-session"})
+def example_agent_session():
+    """Resume a conversation across runs via ``message_history``.
+
+    The agent step seeds itself with the prior transcript and re-emits the
+    updated transcript to XCom (key ``message_history``). Loading and storing
+    that transcript under a session key is the DAG's job -- here, a JSON file in
+    object storage keyed by ``session_id``. Swap the path for ``s3://`` /
+    ``gs://`` in a deployment.
+    """
+    sessions_root = ObjectStoragePath("file:///tmp/airflow_agent_sessions")
+
+    @task
+    def load_history(session_id: str) -> str:
+        path = sessions_root / f"{session_id}.json"
+        # First turn: no file yet -> start a fresh session (empty transcript).
+        return path.read_text() if path.exists() else "[]"
+
+    @task.agent(
+        llm_conn_id="pydanticai_default",
+        system_prompt="You are a helpful assistant. Use the earlier turns for context.",
+        # The XComArg both wires the dependency and resolves to the JSON transcript.
+        message_history=load_history("{{ params.session_id }}"),
+    )
+    def ask(question: str) -> str:
+        return question
+
+    @task
+    def save_history(session_id: str, transcript: str) -> None:
+        # Local/fsspec object storage does not auto-create parent dirs on write.
+        sessions_root.mkdir(parents=True, exist_ok=True)
+        (sessions_root / f"{session_id}.json").write_text(transcript)
+
+    answer = ask("And what did I ask you a moment ago?")
+    saved = save_history(
+        "{{ params.session_id }}",
+        # The agent step pushes the post-run transcript under this XCom key.
+        "{{ ti.xcom_pull(task_ids='ask', key='message_history') }}",
+    )
+    # save runs after the agent so the pulled transcript is the fresh one.
+    answer >> saved
+
+
+# [END howto_agent_session]
+
+example_agent_session()
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index bda06ea1f56..56c9ec5bbb6 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -48,6 +48,7 @@ except ImportError:  # pragma: no cover - cores before the worker-side registrat
 
 if TYPE_CHECKING:
     from pydantic_ai import Agent
+    from pydantic_ai.messages import ModelMessage
     from pydantic_ai.toolsets.abstract import AbstractToolset
     from pydantic_ai.usage import UsageLimits
 
@@ -166,6 +167,22 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
         Cannot be combined with ``durable=True`` (durable replay assumes a
         stable per-step call order that code mode does not guarantee).
         Default ``False``.
+    :param message_history: Prior conversation to seed the run with, for
+        multi-turn sessions that span task runs. Accepts a ``list`` of
+        pydantic-ai ``ModelMessage`` objects, or their JSON form as ``str`` /
+        ``bytes`` -- e.g.
+        ``"{{ ti.xcom_pull(task_ids='ask', key='message_history', 
default='[]') }}"``
+        (pass ``default='[]'`` so the first run, with no XCom yet, starts a 
fresh
+        session instead of failing to parse the string ``"None"``). ``None``
+        (default) is a single-turn run -- no behavior change. When set (an 
empty
+        ``[]`` / ``""`` starts a fresh session), the full transcript after the 
run
+        -- ``result.all_messages()`` -- is pushed to XCom under the key
+        ``message_history`` so the next run can resume. Persisting that 
transcript
+        under a session key (e.g. in object storage) is the DAG's 
responsibility.
+        The transcript is cumulative and grows each turn; for long sessions 
use an
+        object-storage XCom backend or trim old turns. Not supported together 
with
+        ``enable_hitl_review`` (raises) -- the post-review transcript is not 
yet
+        recoverable.
 
     **HITL Review parameters** (requires the ``hitl_review`` plugin):
 
@@ -199,6 +216,7 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
         "model_id",
         "system_prompt",
         "agent_params",
+        "message_history",
     )
 
     operator_extra_links = (HITLReviewLink(),)
@@ -217,6 +235,7 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
         usage_limits: UsageLimits | None = None,
         durable: bool = False,
         code_mode: bool = False,
+        message_history: list[ModelMessage] | str | bytes | None = None,
         # Agent feedback parameters
         enable_hitl_review: bool = False,
         max_hitl_iterations: int = 5,
@@ -240,6 +259,7 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
         self.enable_tool_logging = enable_tool_logging
         self.agent_params = agent_params or {}
         self.usage_limits = usage_limits
+        self.message_history = message_history
 
         self.durable = durable
         self.code_mode = code_mode
@@ -256,6 +276,13 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
             # replay. Reject the combination rather than silently 
mis-replaying.
             raise ValueError("durable=True and code_mode=True cannot be used 
together.")
 
+        if message_history is not None and enable_hitl_review:
+            # The post-review transcript is not recoverable today 
(run_hitl_review
+            # returns only the final string), so emitting the pre-review 
transcript
+            # would silently drop the human-approved turns. Block until HITL 
can
+            # surface the final message history.
+            raise ValueError("message_history and enable_hitl_review=True 
cannot be used together.")
+
         self.enable_hitl_review = enable_hitl_review
         self.max_hitl_iterations = max_hitl_iterations
         self.hitl_timeout = hitl_timeout
@@ -331,6 +358,11 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
 
         agent = self._build_agent()
 
+        run_kwargs: dict[str, Any] = {"usage_limits": self.usage_limits}
+        history = self._resolve_message_history()
+        if history is not None:
+            run_kwargs["message_history"] = history
+
         storage = self._durable_storage
         counter = self._durable_counter
         if self.durable and storage is not None and counter is not None:
@@ -343,9 +375,9 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
             resolved_model = infer_model(agent.model)
             caching_model = CachingModel(resolved_model, storage=storage, 
counter=counter)
             with agent.override(model=caching_model):
-                result = agent.run_sync(self.prompt, 
usage_limits=self.usage_limits)
+                result = agent.run_sync(self.prompt, **run_kwargs)
         else:
-            result = agent.run_sync(self.prompt, 
usage_limits=self.usage_limits)
+            result = agent.run_sync(self.prompt, **run_kwargs)
 
         log_run_summary(self.log, result)
 
@@ -368,6 +400,9 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
         if self._durable_storage is not None:
             self._durable_storage.cleanup()
 
+        if self.message_history is not None:
+            self._emit_message_history(context, result)
+
         output = result.output
 
         if self.enable_hitl_review:
@@ -391,6 +426,36 @@ class AgentOperator(BaseOperator, HITLReviewMixin):
             output = output.model_dump()
         return output
 
+    def _resolve_message_history(self) -> list[ModelMessage] | None:
+        """
+        Deserialize :attr:`message_history` into a list of pydantic-ai messages.
+
+        ``None`` means single-turn (no history passed to the run). A ``str`` /
+        ``bytes`` value is parsed as the JSON the operator emits to XCom; a list
+        (of ``ModelMessage`` objects or their dict form) is validated as-is.
+        """
+        raw = self.message_history
+        if raw is None:
+            return None
+        if isinstance(raw, (str, bytes)) and not raw.strip():
+            # A template that renders to empty (no prior XCom) starts a fresh session.
+            return []
+        # pydantic-ai is imported lazily here to match this module's pattern of
+        # keeping pydantic-ai out of DAG-parse-time imports.
+        from pydantic_ai.messages import ModelMessagesTypeAdapter
+
+        if isinstance(raw, (str, bytes)):
+            return ModelMessagesTypeAdapter.validate_json(raw)
+        return ModelMessagesTypeAdapter.validate_python(raw)
+
+    def _emit_message_history(self, context: Context, result: Any) -> None:
+        """Push the full post-run transcript to XCom for the next turn to resume."""
+        # Lazy import: see _resolve_message_history.
+        from pydantic_ai.messages import ModelMessagesTypeAdapter
+
+        transcript = ModelMessagesTypeAdapter.dump_json(result.all_messages()).decode()
+        context["task_instance"].xcom_push(key="message_history", value=transcript)
+
     def regenerate_with_feedback(self, *, feedback: str, message_history: Any) 
-> tuple[str, Any]:
         """Re-run the agent with *feedback* appended to the conversation 
history."""
         agent = self._build_agent()
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
index a9f017b94ee..1288dbbe652 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -22,6 +22,13 @@ from unittest.mock import MagicMock, patch
 
 import pytest
 from pydantic import BaseModel
+from pydantic_ai.messages import (
+    ModelMessagesTypeAdapter,
+    ModelRequest,
+    ModelResponse,
+    TextPart,
+    UserPromptPart,
+)
 from pydantic_ai.usage import UsageLimits
 
 from airflow.providers.common.ai.operators.agent import AgentOperator, 
HITLReviewLink, _build_code_mode
@@ -90,7 +97,14 @@ class TestAgentOperatorValidation:
 
 class TestAgentOperatorTemplateFields:
     def test_template_fields(self):
-        expected = {"prompt", "llm_conn_id", "model_id", "system_prompt", "agent_params"}
+        expected = {
+            "prompt",
+            "llm_conn_id",
+            "model_id",
+            "system_prompt",
+            "agent_params",
+            "message_history",
+        }
         assert set(AgentOperator.template_fields) == expected
 
 
@@ -617,3 +631,136 @@ class TestAgentOperatorMultimodalPromptGuard:
             op.execute(context=MagicMock())
 
         mock_agent.run_sync.assert_not_called()
+
+
+def _sample_history():
+    """A minimal two-message pydantic-ai conversation for round-trip tests."""
+    return [
+        ModelRequest(parts=[UserPromptPart(content="first question")]),
+        ModelResponse(parts=[TextPart(content="first answer")]),
+    ]
+
+
+# Serialized input forms of _sample_history() (JSON str, list of dicts), computed once at collection time.
+_SAMPLE_HISTORY_JSON = ModelMessagesTypeAdapter.dump_json(_sample_history()).decode()
+_SAMPLE_HISTORY_DICTS = ModelMessagesTypeAdapter.dump_python(_sample_history(), mode="json")
+
+
+class TestAgentOperatorMessageHistory:
+    """Multi-turn session support: seed run_sync with prior history, emit the transcript."""
+
+    @pytest.mark.parametrize(
+        ("raw", "expected_len"),
+        [
+            pytest.param([], 0, id="empty-list"),
+            pytest.param("", 0, id="empty-str"),
+            pytest.param("   ", 0, id="blank-str"),
+            pytest.param(_SAMPLE_HISTORY_JSON, 2, id="json-str"),
+            pytest.param(_SAMPLE_HISTORY_DICTS, 2, id="list-of-dicts"),
+            pytest.param(_sample_history(), 2, id="list-of-objects"),
+        ],
+    )
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_message_history_seeds_run_sync(self, mock_hook_cls, raw, expected_len):
+        """Every accepted input form is deserialized and passed to run_sync; blank/empty start fresh."""
+        mock_agent = _make_mock_agent("ok")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+        op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c", message_history=raw)
+        op.execute(context=MagicMock())
+
+        passed = mock_agent.run_sync.call_args.kwargs["message_history"]
+        assert len(passed) == expected_len
+        assert all(isinstance(m, (ModelRequest, ModelResponse)) for m in passed)
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_none_is_single_turn_no_history_no_emit(self, mock_hook_cls):
+        """Default message_history=None passes no history and pushes no transcript XCom."""
+        mock_agent = _make_mock_agent("ok")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+        op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c")
+        context = MagicMock()
+        op.execute(context=context)
+
+        assert "message_history" not in mock_agent.run_sync.call_args.kwargs
+        context["task_instance"].xcom_push.assert_not_called()
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_transcript_emitted_to_xcom_when_history_set(self, mock_hook_cls):
+        """When message_history is set, the post-run transcript is pushed to XCom and round-trips."""
+        mock_agent = _make_mock_agent("ok")
+        mock_agent.run_sync.return_value.all_messages.return_value = _sample_history()
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+        op = AgentOperator(task_id="t", prompt="run", llm_conn_id="c", message_history=[])
+        context = MagicMock()
+        op.execute(context=context)
+
+        ti = context["task_instance"]
+        ti.xcom_push.assert_called_once()
+        push_kwargs = ti.xcom_push.call_args.kwargs
+        assert push_kwargs["key"] == "message_history"
+        restored = ModelMessagesTypeAdapter.validate_json(push_kwargs["value"])
+        assert len(restored) == 2
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_usage_limits_still_forwarded_with_history(self, mock_hook_cls):
+        """Adding message_history does not drop usage_limits from the run_sync call."""
+        mock_agent = _make_mock_agent("ok")
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+
+        limits = UsageLimits(request_limit=2)
+        op = AgentOperator(
+            task_id="t", prompt="run", llm_conn_id="c", usage_limits=limits, message_history=[]
+        )
+        op.execute(context=MagicMock())
+
+        kwargs = mock_agent.run_sync.call_args.kwargs
+        assert kwargs["usage_limits"] is limits
+        assert kwargs["message_history"] == []
+
+    def test_message_history_with_hitl_review_raises(self):
+        """message_history cannot be combined with HITL review (post-review transcript is lost)."""
+        with pytest.raises(ValueError, match="message_history and enable_hitl_review"):
+            AgentOperator(
+                task_id="t",
+                prompt="run",
+                llm_conn_id="c",
+                message_history=[],
+                enable_hitl_review=True,
+            )
+
+    @patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda m: m)
+    @patch("pydantic_ai.models.infer_model", autospec=True)
+    @patch("airflow.providers.common.ai.durable.storage._get_base_path")
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_durable_path_also_seeds_message_history(
+        self, mock_hook_cls, mock_base_path, mock_infer, _, tmp_path
+    ):
+        """The durable branch forwards message_history into the cached run too."""
+        from airflow.sdk import ObjectStoragePath
+
+        mock_base_path.return_value = ObjectStoragePath(f"file://{tmp_path.as_posix()}")
+
+        mock_agent = MagicMock(spec=["run_sync", "model", "override"])
+        mock_agent.run_sync.return_value = _make_mock_run_result("ok")
+        mock_agent.model = "test-model"
+        mock_agent.override.return_value.__enter__ = MagicMock(return_value=None)
+        mock_agent.override.return_value.__exit__ = MagicMock(return_value=False)
+        mock_hook_cls.get_hook.return_value.create_agent.return_value = mock_agent
+        mock_infer.return_value = MagicMock()
+
+        context = MagicMock()
+        context.__getitem__ = MagicMock(
+            return_value=MagicMock(dag_id="d", task_id="t", run_id="r", map_index=-1)
+        )
+
+        history_json = ModelMessagesTypeAdapter.dump_json(_sample_history()).decode()
+        op = AgentOperator(
+            task_id="test", prompt="test", llm_conn_id="my_llm", durable=True, message_history=history_json
+        )
+        op.execute(context=context)
+
+        passed = mock_agent.run_sync.call_args.kwargs["message_history"]
+        assert len(passed) == 2

Reply via email to