This is an automated email from the ASF dual-hosted git repository.
damccorm pushed a commit to branch release-2.74
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/release-2.74 by this push:
new c3727149272 Make sure session creation happens before starting agent (#38477) (#38565)
c3727149272 is described below
commit c372714927275ee1b696942534cb6836fa1fe885
Author: Danny McCormick <[email protected]>
AuthorDate: Thu May 21 10:20:31 2026 -0400
Make sure session creation happens before starting agent (#38477) (#38565)
* Make sure session creation happens before starting agent
* fix var
* Fix up
* Fix text parsing
* yapf
---
.../ml/inference/agent_development_kit.py | 36 +++++++++++---------
.../ml/inference/agent_development_kit_test.py | 38 ++++++++++++++--------
2 files changed, 45 insertions(+), 29 deletions(-)
diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py
index 1130598f06f..386955b0dfa 100644
--- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py
+++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py
@@ -229,17 +229,6 @@ class ADKAgentModelHandler(ModelHandler[str | genai_Content,
for element in batch:
session_id: str = inference_args.get("session_id", str(uuid.uuid4()))
- # Ensure a session exists for this invocation
- try:
- model.session_service.create_session(
- app_name=self._app_name,
- user_id=user_id,
- session_id=session_id,
- )
- except sessions.SessionExistsError:
- # It's okay if the session already exists for shared session IDs.
- pass
-
# Wrap plain strings in a Content object
if isinstance(element, str):
# pyrefly: ignore[bad-instantiation]
@@ -249,7 +238,8 @@ class ADKAgentModelHandler(ModelHandler[str | genai_Content,
message = element
agent_invocations.append(
- self._invoke_agent(model, user_id, session_id, message))
+ self._invoke_agent(
+ model, user_id, session_id, self._app_name, message))
elements_with_sessions.append(element)
# Run all agent invocations concurrently
@@ -274,6 +264,7 @@ class ADKAgentModelHandler(ModelHandler[str | genai_Content,
runner: "Runner",
user_id: str,
session_id: str,
+ app_name: str,
message: genai_Content,
) -> Optional[str]:
"""Drives the ADK event loop and returns the final response text.
@@ -288,15 +279,30 @@ class ADKAgentModelHandler(ModelHandler[str | genai_Content,
The text of the agent's final response, or ``None`` if the agent
produced no final text response.
"""
+ # Check for your specific session ID
+ try:
+ # Attempt to get the specific session
+ await runner.session_service.get_session(session_id)
+ except Exception as e:
+ await runner.session_service.create_session(
+ app_name=app_name,
+ user_id=user_id,
+ session_id=session_id,
+ )
+
async for event in runner.run_async(
user_id=user_id,
session_id=session_id,
new_message=message,
):
if event.is_final_response():
- if event.content:
- return event.content.text
- return None
+ if event.content and event.content.parts:
+ return "".join([p.text for p in event.content.parts])
+ raise ValueError(
+ f"Agent {runner.agent.name} did not return a response, "
+ f"final event: {event}")
+
+ raise ValueError(f"Agent {runner.agent.name} did not return a response")
def get_metrics_namespace(self) -> str:
return "ADKAgentModelHandler"
diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py
index 6d59bceb9d3..6c8b5c5b351 100644
--- a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py
+++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py
@@ -41,9 +41,11 @@ def _make_mock_runner(
final_text: str = "Hello from agent",
) -> mock.MagicMock:
"""Returns a mock Runner whose run_async yields one final-response event."""
- # Build a mock event that looks like a final response
+ part = mock.MagicMock()
+ part.text = final_text
+
content = mock.MagicMock()
- content.text = final_text
+ content.parts = [part]
event = mock.MagicMock()
event.is_final_response.return_value = True
@@ -56,6 +58,9 @@ def _make_mock_runner(
runner.agent = agent
runner.run_async = mock.MagicMock(side_effect=_async_gen)
runner.session_service = mock.MagicMock()
+ runner.session_service.get_session = mock.AsyncMock(
+ side_effect=Exception("Session not found"))
+ runner.session_service.create_session = mock.AsyncMock()
return runner
@@ -251,8 +256,8 @@ class TestSessionManagement(unittest.TestCase):
class TestResponseExtraction(unittest.TestCase):
"""Tests for extraction of the final response from the event stream."""
- def test_returns_none_when_no_final_response(self):
- """Agent emits only non-final events; inference should be None."""
+ def test_raises_when_no_final_response(self):
+ """Agent emits only non-final events; should raise ValueError."""
agent = _make_mock_agent()
# Build a runner that yields only non-final events
@@ -266,14 +271,14 @@ class TestResponseExtraction(unittest.TestCase):
runner.agent = agent
runner.run_async = mock.MagicMock(side_effect=_async_gen)
runner.session_service = mock.MagicMock()
+ runner.session_service.get_session =
mock.AsyncMock(side_effect=Exception())
+ runner.session_service.create_session = mock.AsyncMock()
handler = ADKAgentModelHandler(agent=agent)
- results = list(handler.run_inference(batch=["hello"], model=runner))
-
- self.assertEqual(len(results), 1)
- self.assertIsNone(results[0].inference)
+ with self.assertRaisesRegex(ValueError, "did not return a response"):
+ list(handler.run_inference(batch=["hello"], model=runner))
- def test_returns_none_when_final_event_has_no_content(self):
+ def test_raises_when_final_event_has_no_content(self):
agent = _make_mock_agent()
event = mock.MagicMock()
@@ -287,19 +292,22 @@ class TestResponseExtraction(unittest.TestCase):
runner.agent = agent
runner.run_async = mock.MagicMock(side_effect=_async_gen)
runner.session_service = mock.MagicMock()
+ runner.session_service.get_session =
mock.AsyncMock(side_effect=Exception())
+ runner.session_service.create_session = mock.AsyncMock()
handler = ADKAgentModelHandler(agent=agent)
- results = list(handler.run_inference(batch=["hello"], model=runner))
-
- self.assertIsNone(results[0].inference)
+ with self.assertRaisesRegex(ValueError, "did not return a response"):
+ list(handler.run_inference(batch=["hello"], model=runner))
def test_stops_after_first_final_response(self):
"""Multiple final events: only the first one's text should be used."""
agent = _make_mock_agent()
def _make_event(text: str):
+ part = mock.MagicMock()
+ part.text = text
content = mock.MagicMock()
- content.text = text
+ content.parts = [part]
event = mock.MagicMock()
event.is_final_response.return_value = True
event.content = content
@@ -313,6 +321,8 @@ class TestResponseExtraction(unittest.TestCase):
runner.agent = agent
runner.run_async = mock.MagicMock(side_effect=_async_gen)
runner.session_service = mock.MagicMock()
+ runner.session_service.get_session =
mock.AsyncMock(side_effect=Exception())
+ runner.session_service.create_session = mock.AsyncMock()
handler = ADKAgentModelHandler(agent=agent)
results = list(handler.run_inference(batch=["hi"], model=runner))
@@ -326,7 +336,7 @@ class TestResponseExtraction(unittest.TestCase):
result = asyncio.run(
ADKAgentModelHandler._invoke_agent(
- runner, "user", "session-1", mock.MagicMock()))
+ runner, "user", "session-1", "test_app", mock.MagicMock()))
self.assertEqual(result, "direct result")