This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/adkHandler
in repository https://gitbox.apache.org/repos/asf/beam.git

commit 3788a6b7cdf9c45179bf080a7f86dc1f52e0700b
Author: Danny Mccormick <[email protected]>
AuthorDate: Mon Mar 23 10:36:26 2026 -0400

    Add ADK model handler
---
 .../ml/inference/agent_development_kit.py          | 284 ++++++++++++++++
 .../ml/inference/agent_development_kit_test.py     | 356 +++++++++++++++++++++
 2 files changed, 640 insertions(+)

diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit.py b/sdks/python/apache_beam/ml/inference/agent_development_kit.py
new file mode 100644
index 00000000000..59dc0cfb2e0
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/agent_development_kit.py
@@ -0,0 +1,284 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""ModelHandler for running agents built with the Google Agent Development Kit.
+
+This module provides :class:`ADKAgentModelHandler`, a Beam
+:class:`~apache_beam.ml.inference.base.ModelHandler` that wraps an ADK
+:class:`google.adk.agents.llm_agent.LlmAgent` so it can be used with the
+:class:`~apache_beam.ml.inference.base.RunInference` transform.
+
+**NOTE:** This API and its implementation are under development and do not
+provide backward compatibility guarantees.
+
+Typical usage::
+
+    import apache_beam as beam
+    from apache_beam.ml.inference.base import RunInference
+    from apache_beam.ml.inference.agent_development_kit import (
+        ADKAgentModelHandler)
+    from google.adk.agents import LlmAgent
+
+    agent = LlmAgent(
+        name="my_agent",
+        model="gemini-2.0-flash",
+        instruction="You are a helpful assistant.",
+    )
+
+    with beam.Pipeline() as p:
+        results = (
+            p
+            | beam.Create(["What is the capital of France?"])
+            | RunInference(ADKAgentModelHandler(agent=agent))
+        )
+
+If your agent contains state that is not picklable (e.g. tool closures that
+capture unpicklable objects), pass a zero-arg factory callable instead::
+
+    handler = ADKAgentModelHandler(agent=lambda: LlmAgent(...))
+
+"""
+
+import asyncio
+import inspect
+import logging
+import uuid
+from collections.abc import Callable
+from collections.abc import Iterable
+from collections.abc import Sequence
+from typing import Any
+from typing import Optional
+from typing import Union
+
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+
+try:
+  from google.adk.agents import Agent
+  from google.adk.runners import Runner
+  from google.adk.sessions import BaseSessionService
+  from google.adk.sessions import InMemorySessionService
+  from google.genai import types as genai_types
+  ADK_AVAILABLE = True
+except ImportError:
+  ADK_AVAILABLE = False
+
+# Module-scoped logger so that configuring the ``apache_beam`` logger
+# hierarchy (levels, handlers) also applies to this handler's messages.
+LOGGER = logging.getLogger(__name__)
+
+# Type alias for an agent or a zero-arg factory that produces one.
+_AgentOrFactory = Union["Agent", Callable[[], "Agent"]]
+
+
+class ADKAgentModelHandler(ModelHandler[Union[str, Any], PredictionResult,
+                                        "Runner"]):
+  """ModelHandler for running ADK agents with the Beam RunInference transform.
+
+  Accepts either a fully constructed :class:`google.adk.agents.Agent` or a
+  zero-arg factory callable that produces one. The factory form is useful when
+  the agent contains state that is not picklable and therefore cannot be
+  serialized alongside the pipeline graph.
+
+  Each call to :meth:`run_inference` invokes the agent once per element in the
+  batch, sequentially and in batch order. By default every invocation uses a
+  fresh, isolated session (stateless). Stateful multi-turn conversations can
+  be achieved by passing a ``session_id`` key inside ``inference_args``;
+  elements sharing the same ``session_id`` continue the same conversation
+  history. That session is created on first use and reused afterwards. With
+  the default :class:`~google.adk.sessions.InMemorySessionService`, the
+  history lives only in the memory of the worker that loaded the model.
+
+  Args:
+    agent: A pre-constructed :class:`~google.adk.agents.Agent` instance, or a
+      zero-arg callable that returns one. The callable form defers agent
+      construction to worker ``load_model`` time, which is useful when the
+      agent cannot be serialized.
+    app_name: The ADK application name used to namespace sessions. Defaults to
+      ``"beam_inference"``.
+    session_service_factory: Optional zero-arg callable returning a
+      :class:`~google.adk.sessions.BaseSessionService`. When ``None``, an
+      :class:`~google.adk.sessions.InMemorySessionService` is created
+      automatically.
+    min_batch_size: Optional minimum batch size.
+    max_batch_size: Optional maximum batch size.
+    max_batch_duration_secs: Optional maximum time to buffer a batch before
+      emitting; used in streaming contexts.
+    max_batch_weight: Optional maximum total weight of a batch.
+    element_size_fn: Optional function that returns the size (weight) of an
+      element.
+  """
+
+  # Keys of ``inference_args`` that run_inference understands.
+  _SUPPORTED_INFERENCE_ARGS = frozenset({"session_id", "user_id"})
+
+  def __init__(
+      self,
+      agent: _AgentOrFactory,
+      app_name: str = "beam_inference",
+      session_service_factory: Optional[Callable[[], "BaseSessionService"]] =
+      None,
+      *,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      max_batch_duration_secs: Optional[int] = None,
+      max_batch_weight: Optional[int] = None,
+      element_size_fn: Optional[Callable[[Any], int]] = None,
+      **kwargs):
+    if not ADK_AVAILABLE:
+      raise ImportError(
+          "google-adk is required to use ADKAgentModelHandler. "
+          "Install it with: pip install google-adk")
+
+    if agent is None:
+      raise ValueError("'agent' must be an Agent instance or a callable.")
+
+    self._agent_or_factory = agent
+    self._app_name = app_name
+    self._session_service_factory = session_service_factory
+
+    super().__init__(
+        min_batch_size=min_batch_size,
+        max_batch_size=max_batch_size,
+        max_batch_duration_secs=max_batch_duration_secs,
+        max_batch_weight=max_batch_weight,
+        element_size_fn=element_size_fn,
+        **kwargs)
+
+  def load_model(self) -> "Runner":
+    """Instantiates the ADK Runner on the worker.
+
+    Resolves the agent (calling the factory if a callable was provided), then
+    creates a :class:`~google.adk.runners.Runner` backed by the configured
+    session service.
+
+    Returns:
+      A fully initialised :class:`~google.adk.runners.Runner`.
+    """
+    # Anything callable that is not already an Agent is treated as a factory.
+    if callable(self._agent_or_factory) and not isinstance(
+        self._agent_or_factory, Agent):
+      agent = self._agent_or_factory()
+    else:
+      agent = self._agent_or_factory
+
+    if self._session_service_factory is not None:
+      session_service = self._session_service_factory()
+    else:
+      session_service = InMemorySessionService()
+
+    runner = Runner(
+        agent=agent,
+        app_name=self._app_name,
+        session_service=session_service,
+    )
+    LOGGER.info(
+        "Loaded ADK Runner for agent '%s' (app_name='%s')",
+        agent.name,
+        self._app_name,
+    )
+    return runner
+
+  def validate_inference_args(
+      self, inference_args: Optional[dict[str, Any]]) -> None:
+    """Validates the ``inference_args`` given to RunInference.
+
+    The base ModelHandler implementation rejects any non-empty
+    ``inference_args``, which would make the ``session_id`` and ``user_id``
+    options documented on :meth:`run_inference` unusable through
+    RunInference. This override accepts exactly those keys and rejects any
+    other key so that typos fail fast instead of being silently ignored.
+
+    Raises:
+      ValueError: If ``inference_args`` contains an unsupported key.
+    """
+    if not inference_args:
+      return
+    unsupported = set(inference_args) - self._SUPPORTED_INFERENCE_ARGS
+    if unsupported:
+      raise ValueError(
+          "Unsupported inference_args for ADKAgentModelHandler: "
+          f"{sorted(unsupported)}. Supported keys are: "
+          f"{sorted(self._SUPPORTED_INFERENCE_ARGS)}.")
+
+  def run_inference(
+      self,
+      batch: Sequence[Union[str, Any]],
+      model: "Runner",
+      inference_args: Optional[dict[str, Any]] = None,
+  ) -> Iterable[PredictionResult]:
+    """Runs the ADK agent on each element in the batch.
+
+    Each element is sent to the agent as a new user turn. The final response
+    text from the agent is returned as the ``inference`` field of a
+    :class:`~apache_beam.ml.inference.base.PredictionResult`. The elements of
+    a batch are processed sequentially, in order, on a single event loop.
+
+    Args:
+      batch: A sequence of inputs, each of which is either a ``str`` (the user
+        message text) or a :class:`google.genai.types.Content` object (for
+        richer multi-part messages).
+      model: The :class:`~google.adk.runners.Runner` returned by
+        :meth:`load_model`.
+      inference_args: Optional dict of extra arguments. Supported keys:
+
+        - ``"session_id"`` (:class:`str`): If supplied (and not ``None``),
+          all elements in this batch share this session ID, enabling
+          stateful multi-turn conversations. The session is created if it
+          does not exist yet. If omitted, each element receives a unique
+          auto-generated session ID.
+        - ``"user_id"`` (:class:`str`): The user identifier to pass to the
+          runner. Defaults to ``"beam_user"``.
+
+    Returns:
+      A list of :class:`~apache_beam.ml.inference.base.PredictionResult`, one
+      per input element, in batch order.
+
+    Raises:
+      RuntimeError: Propagated from :func:`asyncio.run` if the calling thread
+        is already running an event loop.
+    """
+    if not batch:
+      return []
+    if inference_args is None:
+      inference_args = {}
+
+    user_id: str = inference_args.get("user_id", "beam_user")
+    shared_session_id: Optional[str] = inference_args.get("session_id")
+
+    messages = []
+    for element in batch:
+      if isinstance(element, str):
+        # Wrap plain strings in a single-part user Content object.
+        messages.append(
+            genai_types.Content(
+                role="user", parts=[genai_types.Part(text=element)]))
+      else:
+        # Assume the caller has already constructed a types.Content object.
+        messages.append(element)
+
+    # One event loop for the whole batch rather than one per element.
+    responses = asyncio.run(
+        self._invoke_batch(model, user_id, shared_session_id, messages))
+
+    return [
+        PredictionResult(
+            example=element,
+            inference=response_text,
+            model_id=model.agent.name,
+        ) for element, response_text in zip(batch, responses)
+    ]
+
+  async def _invoke_batch(
+      self,
+      runner: "Runner",
+      user_id: str,
+      shared_session_id: Optional[str],
+      messages: Sequence[Any],
+  ) -> list[Optional[str]]:
+    """Invokes the agent once per message, in order, on the running loop.
+
+    Args:
+      runner: The ADK Runner to invoke.
+      user_id: The user ID used for every invocation.
+      shared_session_id: If not ``None``, every message is sent to this
+        session, which is created first if the session service does not have
+        it. Otherwise each message gets a fresh session with a random UUID4.
+      messages: The :class:`google.genai.types.Content` objects to send.
+
+    Returns:
+      The final response text (or ``None``) for each message, in order.
+    """
+    if shared_session_id is not None:
+      await self._ensure_session(runner, user_id, shared_session_id)
+
+    responses: list[Optional[str]] = []
+    for message in messages:
+      session_id = shared_session_id
+      if session_id is None:
+        session_id = str(uuid.uuid4())
+        # A fresh UUID cannot name an existing session, so create it directly
+        # without a lookup.
+        await self._maybe_await(
+            runner.session_service.create_session(
+                app_name=self._app_name,
+                user_id=user_id,
+                session_id=session_id,
+            ))
+      responses.append(
+          await self._invoke_agent(runner, user_id, session_id, message))
+    return responses
+
+  async def _ensure_session(
+      self, runner: "Runner", user_id: str, session_id: str) -> None:
+    """Creates session ``session_id`` unless the session service has it.
+
+    A caller-supplied ``session_id`` comes from ``inference_args`` and is
+    therefore reused for every batch. ADK session services may reject
+    creating a session whose ID already exists (NOTE(review): e.g.
+    InMemorySessionService in google-adk 1.x), so look it up first.
+    """
+    session_service = runner.session_service
+    existing = await self._maybe_await(
+        session_service.get_session(
+            app_name=self._app_name, user_id=user_id, session_id=session_id))
+    if existing is None:
+      await self._maybe_await(
+          session_service.create_session(
+              app_name=self._app_name,
+              user_id=user_id,
+              session_id=session_id,
+          ))
+
+  @staticmethod
+  async def _maybe_await(value: Any) -> Any:
+    """Awaits ``value`` if it is awaitable, otherwise returns it unchanged.
+
+    Session service methods are coroutines in current ADK releases but plain
+    synchronous methods in older ones (and in custom services); calling a
+    coroutine method without awaiting it would silently do nothing.
+    """
+    if inspect.isawaitable(value):
+      return await value
+    return value
+
+  @staticmethod
+  async def _invoke_agent(
+      runner: "Runner",
+      user_id: str,
+      session_id: str,
+      message: Any,
+  ) -> Optional[str]:
+    """Drives the ADK event loop and returns the final response text.
+
+    Only the first event whose ``is_final_response()`` is true is used;
+    iteration stops there.
+
+    Args:
+      runner: The ADK Runner to invoke.
+      user_id: The user ID for this invocation.
+      session_id: The session ID for this invocation.
+      message: The :class:`google.genai.types.Content` to send.
+
+    Returns:
+      The concatenated text of every part of the final response that has
+      text and is not flagged as a model thought, or ``None`` if the agent
+      produced no final response or that response carried no text.
+    """
+    async for event in runner.run_async(
+        user_id=user_id,
+        session_id=session_id,
+        new_message=message,
+    ):
+      if not event.is_final_response():
+        continue
+      if not (event.content and event.content.parts):
+        return None
+      # The answer may be split over several parts, and thinking models can
+      # emit reasoning parts (``thought=True``) ahead of it, so parts[0]
+      # alone is not a reliable answer.
+      texts = [
+          part.text for part in event.content.parts
+          if part.text is not None and
+          getattr(part, "thought", None) is not True
+      ]
+      return "".join(texts) if texts else None
+    return None
+
+  def get_metrics_namespace(self) -> str:
+    return "ADKAgentModelHandler"
diff --git a/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py
new file mode 100644
index 00000000000..7bd77c52ff1
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/agent_development_kit_test.py
@@ -0,0 +1,356 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import asyncio
+import unittest
+from unittest import mock
+
+try:
+  from apache_beam.ml.inference.agent_development_kit import (
+      ADKAgentModelHandler)
+  from apache_beam.ml.inference.agent_development_kit import ADK_AVAILABLE
+  from apache_beam.ml.inference.base import PredictionResult
+except ImportError:
+  raise unittest.SkipTest('google-adk dependencies are not installed')
+
+# The handler module catches its own google-adk ImportError and only records
+# it in ADK_AVAILABLE, so the import above succeeds without google-adk; skip
+# explicitly in that case instead of failing in ADKAgentModelHandler.__init__.
+if not ADK_AVAILABLE:
+  raise unittest.SkipTest('google-adk dependencies are not installed')
+
+
+def _make_mock_agent(name: str = "test_agent") -> mock.NonCallableMagicMock:
+  """Returns a mock that quacks like a google.adk.agents.Agent.
+
+  ADKAgentModelHandler.load_model treats any callable that is not an
+  ``Agent`` instance as a zero-arg agent factory. A plain MagicMock is
+  callable, so load_model would call it and hand its ``return_value`` (not
+  this mock) to the Runner; a NonCallableMagicMock is used instead.
+
+  Args:
+    name: Value exposed as the mock agent's ``name`` attribute.
+  """
+  agent = mock.NonCallableMagicMock()
+  agent.name = name
+  return agent
+
+
+def _make_mock_runner(
+    agent: mock.MagicMock,
+    final_text: str = "Hello from agent",
+) -> mock.MagicMock:
+  """Builds a fake ADK Runner bound to ``agent``.
+
+  Every call to ``run_async`` returns a new async generator that yields one
+  final-response event whose single part carries ``final_text``. The
+  ``session_service`` is a bare MagicMock so tests can inspect its calls.
+  """
+  final_event = mock.MagicMock(
+      content=mock.MagicMock(parts=[mock.MagicMock(text=final_text)]))
+  final_event.is_final_response.return_value = True
+
+  async def _events(*unused_args, **unused_kwargs):
+    yield final_event
+
+  fake_runner = mock.MagicMock()
+  fake_runner.agent = agent
+  fake_runner.run_async = mock.MagicMock(side_effect=_events)
+  fake_runner.session_service = mock.MagicMock()
+  return fake_runner
+
+
+# ---------------------------------------------------------------------------
+# Dotted path of the module under test. Tests prefix mock.patch targets with
+# it (e.g. f"{_MODULE}.Runner") so that load_model builds mocks instead of
+# real ADK objects.
+# ---------------------------------------------------------------------------
+_MODULE = ADKAgentModelHandler.__module__
+
+
+class TestADKAgentModelHandlerInit(unittest.TestCase):
+  """Tests for __init__ argument validation and stored configuration."""
+
+  def test_raises_if_agent_is_none(self):
+    # __init__ rejects None explicitly with ValueError; pinning the exact type
+    # keeps an unrelated TypeError from masking a regression.
+    with self.assertRaises(ValueError):
+      ADKAgentModelHandler(agent=None)  # type: ignore[arg-type]
+
+  def test_raises_import_error_when_adk_unavailable(self):
+    with mock.patch(f"{_MODULE}.ADK_AVAILABLE", False):
+      with self.assertRaises(ImportError):
+        ADKAgentModelHandler(agent=_make_mock_agent())
+
+  def test_accepts_agent_object(self):
+    agent = _make_mock_agent()
+    handler = ADKAgentModelHandler(agent=agent)
+    self.assertIs(handler._agent_or_factory, agent)
+
+  def test_accepts_agent_factory_callable(self):
+    agent = _make_mock_agent()
+
+    def factory():
+      return agent
+
+    handler = ADKAgentModelHandler(agent=factory)
+    self.assertIs(handler._agent_or_factory, factory)
+    self.assertTrue(callable(handler._agent_or_factory))
+
+  def test_stores_session_service_factory(self):
+    session_factory = mock.MagicMock()
+    handler = ADKAgentModelHandler(
+        agent=_make_mock_agent(), session_service_factory=session_factory)
+    self.assertIs(handler._session_service_factory, session_factory)
+    # The factory must only be invoked on the worker, in load_model.
+    session_factory.assert_not_called()
+
+  def test_default_app_name(self):
+    agent = _make_mock_agent()
+    handler = ADKAgentModelHandler(agent=agent)
+    self.assertEqual(handler._app_name, "beam_inference")
+
+  def test_custom_app_name(self):
+    agent = _make_mock_agent()
+    handler = ADKAgentModelHandler(agent=agent, app_name="my_app")
+    self.assertEqual(handler._app_name, "my_app")
+
+  def test_metrics_namespace(self):
+    agent = _make_mock_agent()
+    handler = ADKAgentModelHandler(agent=agent)
+    self.assertEqual(handler.get_metrics_namespace(), "ADKAgentModelHandler")
+
+
+class TestLoadModel(unittest.TestCase):
+  """Checks how load_model resolves the agent and builds the ADK Runner."""
+
+  def test_load_model_with_agent_object(self):
+    agent = _make_mock_agent()
+    with mock.patch(f"{_MODULE}.Runner") as runner_cls:
+      with mock.patch(f"{_MODULE}.InMemorySessionService") as session_cls:
+        ADKAgentModelHandler(agent=agent, app_name="test_app").load_model()
+
+    session_cls.assert_called_once()
+    runner_cls.assert_called_once_with(
+        agent=agent,
+        app_name="test_app",
+        session_service=session_cls.return_value,
+    )
+
+  def test_load_model_calls_factory(self):
+    agent = _make_mock_agent()
+    agent_factory = mock.MagicMock(return_value=agent)
+    with mock.patch(f"{_MODULE}.Runner") as runner_cls:
+      with mock.patch(f"{_MODULE}.InMemorySessionService") as session_cls:
+        ADKAgentModelHandler(agent=agent_factory).load_model()
+
+    agent_factory.assert_called_once()
+    runner_cls.assert_called_once_with(
+        agent=agent,
+        app_name="beam_inference",
+        session_service=session_cls.return_value,
+    )
+
+  def test_load_model_uses_custom_session_service(self):
+    agent = _make_mock_agent()
+    custom_service = mock.MagicMock()
+    service_factory = mock.MagicMock(return_value=custom_service)
+    with mock.patch(f"{_MODULE}.Runner") as runner_cls:
+      ADKAgentModelHandler(
+          agent=agent, session_service_factory=service_factory).load_model()
+
+    service_factory.assert_called_once()
+    runner_cls.assert_called_once_with(
+        agent=agent,
+        app_name="beam_inference",
+        session_service=custom_service,
+    )
+
+
+class TestRunInference(unittest.TestCase):
+  """Checks the PredictionResults that run_inference produces."""
+
+  @staticmethod
+  def _handler_and_runner(**runner_kwargs):
+    """Returns a handler and a fake Runner that share one mock agent."""
+    agent = _make_mock_agent()
+    fake_runner = _make_mock_runner(agent, **runner_kwargs)
+    return ADKAgentModelHandler(agent=agent), fake_runner
+
+  def test_string_input_returns_prediction_result(self):
+    handler, fake_runner = self._handler_and_runner(final_text="Paris")
+    question = "What is the capital of France?"
+    results = list(handler.run_inference(batch=[question], model=fake_runner))
+
+    self.assertEqual(len(results), 1)
+    (result, ) = results
+    self.assertIsInstance(result, PredictionResult)
+    self.assertEqual(result.example, question)
+    self.assertEqual(result.inference, "Paris")
+    self.assertEqual(result.model_id, "test_agent")
+
+  def test_batch_of_strings(self):
+    handler, fake_runner = self._handler_and_runner(final_text="answer")
+    questions = ["q1", "q2", "q3"]
+    results = list(handler.run_inference(batch=questions, model=fake_runner))
+
+    self.assertEqual(len(results), len(questions))
+    self.assertEqual([r.example for r in results], questions)
+
+  def test_content_object_input(self):
+    """Non-string inputs (types.Content) are passed through unchanged."""
+    handler, fake_runner = self._handler_and_runner(final_text="Berlin")
+    fake_content = mock.MagicMock()  # stands in for a types.Content
+    results = list(
+        handler.run_inference(batch=[fake_content], model=fake_runner))
+
+    self.assertEqual(len(results), 1)
+    self.assertEqual(results[0].example, fake_content)
+    self.assertEqual(results[0].inference, "Berlin")
+
+  def test_none_inference_args_uses_defaults(self):
+    handler, fake_runner = self._handler_and_runner()
+    results = list(
+        handler.run_inference(
+            batch=["hello"], model=fake_runner, inference_args=None))
+    self.assertEqual(len(results), 1)
+
+  def test_custom_user_id_passed_to_runner(self):
+    handler, fake_runner = self._handler_and_runner()
+    handler.run_inference(
+        batch=["hi"],
+        model=fake_runner,
+        inference_args={"user_id": "custom_user"},
+    )
+
+    self.assertEqual(
+        fake_runner.run_async.call_args.kwargs["user_id"], "custom_user")
+
+
+class TestSessionManagement(unittest.TestCase):
+  """Checks which sessions run_inference asks the session service to create."""
+
+  @staticmethod
+  def _create_session_kwargs(batch, inference_args=None, **handler_kwargs):
+    """Runs one batch and returns the kwargs of every create_session call."""
+    agent = _make_mock_agent()
+    fake_runner = _make_mock_runner(agent)
+    handler = ADKAgentModelHandler(agent=agent, **handler_kwargs)
+    handler.run_inference(
+        batch=batch, model=fake_runner, inference_args=inference_args)
+    recorded = fake_runner.session_service.create_session.call_args_list
+    return [call.kwargs for call in recorded]
+
+  def test_each_element_gets_unique_session_by_default(self):
+    created = self._create_session_kwargs(["a", "b", "c"])
+
+    # One create_session call per element, each with its own session ID.
+    self.assertEqual(len(created), 3)
+    distinct_ids = {kwargs["session_id"] for kwargs in created}
+    self.assertEqual(len(distinct_ids), 3, "Expected unique session IDs")
+
+  def test_shared_session_id_from_inference_args(self):
+    created = self._create_session_kwargs(
+        ["turn1", "turn2"], inference_args={"session_id": "my-session"})
+
+    self.assertTrue(
+        all(kwargs["session_id"] == "my-session" for kwargs in created),
+        "All elements should share the provided session_id",
+    )
+
+  def test_session_created_with_correct_app_name(self):
+    created = self._create_session_kwargs(["hello"], app_name="my_app")
+    self.assertEqual(created[-1]["app_name"], "my_app")
+
+
+class TestResponseExtraction(unittest.TestCase):
+  """Checks how the final response text is pulled out of the event stream."""
+
+  @staticmethod
+  def _runner_emitting(agent, *events):
+    """Returns a fake Runner whose run_async yields ``events`` in order."""
+
+    async def _stream(*unused_args, **unused_kwargs):
+      for emitted in events:
+        yield emitted
+
+    fake_runner = mock.MagicMock()
+    fake_runner.agent = agent
+    fake_runner.run_async = mock.MagicMock(side_effect=_stream)
+    fake_runner.session_service = mock.MagicMock()
+    return fake_runner
+
+  @staticmethod
+  def _final_event(text):
+    """Returns a final-response event whose single part carries ``text``."""
+    event = mock.MagicMock(
+        content=mock.MagicMock(parts=[mock.MagicMock(text=text)]))
+    event.is_final_response.return_value = True
+    return event
+
+  def test_returns_none_when_no_final_response(self):
+    """Agent emits only non-final events; inference should be None."""
+    agent = _make_mock_agent()
+    intermediate = mock.MagicMock()
+    intermediate.is_final_response.return_value = False
+    fake_runner = self._runner_emitting(agent, intermediate)
+
+    handler = ADKAgentModelHandler(agent=agent)
+    results = list(handler.run_inference(batch=["hello"], model=fake_runner))
+
+    self.assertEqual(len(results), 1)
+    self.assertIsNone(results[0].inference)
+
+  def test_returns_none_when_final_event_has_no_content(self):
+    agent = _make_mock_agent()
+    empty_final = mock.MagicMock(content=None)
+    empty_final.is_final_response.return_value = True
+    fake_runner = self._runner_emitting(agent, empty_final)
+
+    handler = ADKAgentModelHandler(agent=agent)
+    results = list(handler.run_inference(batch=["hello"], model=fake_runner))
+
+    self.assertIsNone(results[0].inference)
+
+  def test_stops_after_first_final_response(self):
+    """Multiple final events: only the first one's text should be used."""
+    agent = _make_mock_agent()
+    fake_runner = self._runner_emitting(
+        agent, self._final_event("first"), self._final_event("second"))
+
+    handler = ADKAgentModelHandler(agent=agent)
+    results = list(handler.run_inference(batch=["hi"], model=fake_runner))
+
+    self.assertEqual(results[0].inference, "first")
+
+  def test_invoke_agent_static_method_directly(self):
+    """Unit test the async _invoke_agent helper directly."""
+    fake_runner = _make_mock_runner(
+        _make_mock_agent(), final_text="direct result")
+    pending = ADKAgentModelHandler._invoke_agent(
+        fake_runner, "user", "session-1", mock.MagicMock())
+    self.assertEqual(asyncio.run(pending), "direct result")
+
+
+if __name__ == "__main__":
+  # Allow running this test module directly as a script.
+  unittest.main()

Reply via email to