This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/session-cp
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 0d9b93fd7b5bdbcc56b3fdbce1d9f7013976f920
Author: Danny McCormick <[email protected]>
AuthorDate: Tue May 19 16:31:35 2026 -0400

    Make sure session creation happens before starting agent (#38477)
    
    * Make sure session creation happens before starting agent
    
    * fix var
    
    * Fix up
    
    * Fix text parsing
    
    * yapf
---
 .../ml/inference/agent_development_kit.py          | 36 +++++++++++---------
 .../ml/inference/agent_development_kit_test.py     | 38 ++++++++++++++--------
 2 files changed, 45 insertions(+), 29 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py
index 1130598f06f..386955b0dfa 100644
--- a/sdks/python/apache_beam/ml/inference/agent_development_kit.py
+++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py
@@ -229,17 +229,6 @@ class ADKAgentModelHandler(ModelHandler[str | genai_Content,
     for element in batch:
       session_id: str = inference_args.get("session_id", str(uuid.uuid4()))
 
-      # Ensure a session exists for this invocation
-      try:
-        model.session_service.create_session(
-            app_name=self._app_name,
-            user_id=user_id,
-            session_id=session_id,
-        )
-      except sessions.SessionExistsError:
-        # It's okay if the session already exists for shared session IDs.
-        pass
-
       # Wrap plain strings in a Content object
       if isinstance(element, str):
         # pyrefly: ignore[bad-instantiation]
@@ -249,7 +238,8 @@ class ADKAgentModelHandler(ModelHandler[str | genai_Content,
         message = element
 
       agent_invocations.append(
-          self._invoke_agent(model, user_id, session_id, message))
+          self._invoke_agent(
+              model, user_id, session_id, self._app_name, message))
       elements_with_sessions.append(element)
 
     # Run all agent invocations concurrently
@@ -274,6 +264,7 @@ class ADKAgentModelHandler(ModelHandler[str | genai_Content,
       runner: "Runner",
       user_id: str,
       session_id: str,
+      app_name: str,
       message: genai_Content,
   ) -> Optional[str]:
     """Drives the ADK event loop and returns the final response text.
@@ -288,15 +279,30 @@ class ADKAgentModelHandler(ModelHandler[str | genai_Content,
       The text of the agent's final response, or ``None`` if the agent
       produced no final text response.
     """
+    # Check for your specific session ID
+    try:
+      # Attempt to get the specific session
+      await runner.session_service.get_session(session_id)
+    except Exception as e:
+      await runner.session_service.create_session(
+          app_name=app_name,
+          user_id=user_id,
+          session_id=session_id,
+      )
+
     async for event in runner.run_async(
         user_id=user_id,
         session_id=session_id,
         new_message=message,
     ):
       if event.is_final_response():
-        if event.content:
-          return event.content.text
-    return None
+        if event.content and event.content.parts:
+          return "".join([p.text for p in event.content.parts])
+        raise ValueError(
+            f"Agent {runner.agent.name} did not return a response, "
+            f"final event: {event}")
+
+    raise ValueError(f"Agent {runner.agent.name} did not return a response")
 
   def get_metrics_namespace(self) -> str:
     return "ADKAgentModelHandler"
diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py
index 6d59bceb9d3..6c8b5c5b351 100644
--- a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py
+++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py
@@ -41,9 +41,11 @@ def _make_mock_runner(
     final_text: str = "Hello from agent",
 ) -> mock.MagicMock:
   """Returns a mock Runner whose run_async yields one final-response event."""
-  # Build a mock event that looks like a final response
+  part = mock.MagicMock()
+  part.text = final_text
+
   content = mock.MagicMock()
-  content.text = final_text
+  content.parts = [part]
 
   event = mock.MagicMock()
   event.is_final_response.return_value = True
@@ -56,6 +58,9 @@ def _make_mock_runner(
   runner.agent = agent
   runner.run_async = mock.MagicMock(side_effect=_async_gen)
   runner.session_service = mock.MagicMock()
+  runner.session_service.get_session = mock.AsyncMock(
+      side_effect=Exception("Session not found"))
+  runner.session_service.create_session = mock.AsyncMock()
   return runner
 
 
@@ -251,8 +256,8 @@ class TestSessionManagement(unittest.TestCase):
 
 class TestResponseExtraction(unittest.TestCase):
   """Tests for extraction of the final response from the event stream."""
-  def test_returns_none_when_no_final_response(self):
-    """Agent emits only non-final events; inference should be None."""
+  def test_raises_when_no_final_response(self):
+    """Agent emits only non-final events; should raise ValueError."""
     agent = _make_mock_agent()
 
     # Build a runner that yields only non-final events
@@ -266,14 +271,14 @@ class TestResponseExtraction(unittest.TestCase):
     runner.agent = agent
     runner.run_async = mock.MagicMock(side_effect=_async_gen)
     runner.session_service = mock.MagicMock()
+    runner.session_service.get_session = 
mock.AsyncMock(side_effect=Exception())
+    runner.session_service.create_session = mock.AsyncMock()
 
     handler = ADKAgentModelHandler(agent=agent)
-    results = list(handler.run_inference(batch=["hello"], model=runner))
-
-    self.assertEqual(len(results), 1)
-    self.assertIsNone(results[0].inference)
+    with self.assertRaisesRegex(ValueError, "did not return a response"):
+      list(handler.run_inference(batch=["hello"], model=runner))
 
-  def test_returns_none_when_final_event_has_no_content(self):
+  def test_raises_when_final_event_has_no_content(self):
     agent = _make_mock_agent()
 
     event = mock.MagicMock()
@@ -287,19 +292,22 @@ class TestResponseExtraction(unittest.TestCase):
     runner.agent = agent
     runner.run_async = mock.MagicMock(side_effect=_async_gen)
     runner.session_service = mock.MagicMock()
+    runner.session_service.get_session = 
mock.AsyncMock(side_effect=Exception())
+    runner.session_service.create_session = mock.AsyncMock()
 
     handler = ADKAgentModelHandler(agent=agent)
-    results = list(handler.run_inference(batch=["hello"], model=runner))
-
-    self.assertIsNone(results[0].inference)
+    with self.assertRaisesRegex(ValueError, "did not return a response"):
+      list(handler.run_inference(batch=["hello"], model=runner))
 
   def test_stops_after_first_final_response(self):
     """Multiple final events: only the first one's text should be used."""
     agent = _make_mock_agent()
 
     def _make_event(text: str):
+      part = mock.MagicMock()
+      part.text = text
       content = mock.MagicMock()
-      content.text = text
+      content.parts = [part]
       event = mock.MagicMock()
       event.is_final_response.return_value = True
       event.content = content
@@ -313,6 +321,8 @@ class TestResponseExtraction(unittest.TestCase):
     runner.agent = agent
     runner.run_async = mock.MagicMock(side_effect=_async_gen)
     runner.session_service = mock.MagicMock()
+    runner.session_service.get_session = 
mock.AsyncMock(side_effect=Exception())
+    runner.session_service.create_session = mock.AsyncMock()
 
     handler = ADKAgentModelHandler(agent=agent)
     results = list(handler.run_inference(batch=["hi"], model=runner))
@@ -326,7 +336,7 @@ class TestResponseExtraction(unittest.TestCase):
 
     result = asyncio.run(
         ADKAgentModelHandler._invoke_agent(
-            runner, "user", "session-1", mock.MagicMock()))
+            runner, "user", "session-1", "test_app", mock.MagicMock()))
     self.assertEqual(result, "direct result")
 
 

Reply via email to