gemini-code-assist[bot] commented on code in PR #39074:
URL: https://github.com/apache/beam/pull/39074#discussion_r3461219111
##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -259,6 +389,26 @@ async def _run_concurrently():
return results
+ def _update_agent_port(self, agent: "Agent", port: int):
+ if ADK_AVAILABLE:
+ from google.adk.models.lite_llm import LiteLlm
+ if hasattr(agent, 'model') and isinstance(agent.model, LiteLlm):
+ agent.model = LiteLlm(
+ model=agent.model.model,
+ api_base=f"http://localhost:{port}/v1"
+ )
+ if hasattr(agent, 'tools'):
+ for tool in agent.tools:
+ if hasattr(tool, 'agent'):
+ self._update_agent_port(tool.agent, port)
+ elif isinstance(tool, Agent):
+ self._update_agent_port(tool, port)
Review Comment:

As with the `tools` loop in `_set_agent_model`, if `agent.tools` is `None`,
iterating over it here will raise a `TypeError`. We should check that
`agent.tools` is not `None` before iterating.
```suggestion
if getattr(agent, 'tools', None) is not None:
for tool in agent.tools:
if hasattr(tool, 'agent'):
self._update_agent_port(tool.agent, port)
elif isinstance(tool, Agent):
self._update_agent_port(tool, port)
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,225 @@ def should_garbage_collect_on_timeout(self) -> bool:
return self.share_model_across_processes()
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+ """Base class for model handlers that spin up a subprocess server."""
+ @abstractmethod
+ def get_port(self, model: ModelT) -> int:
+ """Returns the port the subprocess server is listening on."""
+ pass
+
+ @abstractmethod
+ def get_model_name(self) -> str:
+ """Returns the model name."""
+ pass
+
+ @abstractmethod
+ def check_connectivity(self, model: ModelT) -> None:
+ """Checks connectivity to the server and attempts to recover/mark for
restart."""
+ pass
+
+
+class SubProcessModelServer:
+ """Manages the lifecycle of a generic subprocess model server."""
+ def __init__(self, handler_path: str, model_name: str, port: int = None,
temp_dir: tempfile.TemporaryDirectory = None):
+ self._handler_path = handler_path
+ self._model_name = model_name
+ self._port = port
+ self._temp_dir = temp_dir
+ self._process = None
+ self._server_started = False
+ self._server_process_lock = threading.RLock()
+ self.start_server()
+
+ def start_server(self, retries=3):
+ with self._server_process_lock:
+ if not self._server_started:
Review Comment:

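If the subprocess server crashes, `self._server_started` remains `True`. If
`get_server_port()` is called subsequently, it will return the port without
restarting the server because `self._server_started` is still `True`. We should
check whether the process has exited and reset `self._server_started = False`
at the beginning of `start_server` to ensure recovery.

Note that this alone does not fix the `get_server_port()` path.
`get_server_port()` only calls `start_server()` when `self._server_started` is
already `False`, so after a crash it never reaches the new check. For recovery
to take effect there, `get_server_port()` needs either the same
`self._process.poll() is not None` check or an unconditional call to
`start_server()`.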
```python
def start_server(self, retries=3):
with self._server_process_lock:
if self._process and self._process.poll() is not None:
self._server_started = False
if not self._server_started:
```
##########
sdks/python/apache_beam/ml/inference/agent_development_kit.py:
##########
@@ -181,13 +264,33 @@ def load_model(self) -> "Runner":
app_name=self._app_name,
session_service=session_service,
)
+
+ if underlying_model is not None:
+ runner._underlying_model = underlying_model
+
LOGGER.info(
"Loaded ADK Runner for agent '%s' (app_name='%s')",
agent.name,
self._app_name,
)
return runner
+ def _set_agent_model(self, agent: "Agent", model: Any, is_root: bool =
False):
+ if is_root:
+ if isinstance(agent.model, BeamPlaceholderModel) or agent.model is None:
+ agent.model = model
+ else:
+ if isinstance(agent.model, BeamPlaceholderModel):
+ agent.model = model
+
+ # Speculative propagation to subagents/tools
+ if hasattr(agent, 'tools'):
+ for tool in agent.tools:
+ if hasattr(tool, 'agent'):
+ self._set_agent_model(tool.agent, model, is_root=False)
+ elif isinstance(tool, Agent):
+ self._set_agent_model(tool, model, is_root=False)
Review Comment:

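If `agent.tools` is `None`, `hasattr(agent, 'tools')` still returns `True`, but
attempting to iterate over it will raise
`TypeError: 'NoneType' object is not iterable`.

Whether ADK ever leaves `tools` as `None` when no tools are provided should be
confirmed against the ADK version in use; `LlmAgent` may default it to an empty
list. The guard is cheap either way. We should use
`getattr(agent, 'tools', None) is not None` to safely guard against both a
missing attribute and a `None` value.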
```suggestion
if getattr(agent, 'tools', None) is not None:
for tool in agent.tools:
if hasattr(tool, 'agent'):
self._set_agent_model(tool.agent, model, is_root=False)
elif isinstance(tool, Agent):
self._set_agent_model(tool, model, is_root=False)
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -406,6 +408,225 @@ def should_garbage_collect_on_timeout(self) -> bool:
return self.share_model_across_processes()
+class SubprocessModelHandler(ModelHandler[ExampleT, PredictionT, ModelT], ABC):
+ """Base class for model handlers that spin up a subprocess server."""
+ @abstractmethod
+ def get_port(self, model: ModelT) -> int:
+ """Returns the port the subprocess server is listening on."""
+ pass
+
+ @abstractmethod
+ def get_model_name(self) -> str:
+ """Returns the model name."""
+ pass
+
+ @abstractmethod
+ def check_connectivity(self, model: ModelT) -> None:
+ """Checks connectivity to the server and attempts to recover/mark for
restart."""
+ pass
+
+
+class SubProcessModelServer:
+ """Manages the lifecycle of a generic subprocess model server."""
+ def __init__(self, handler_path: str, model_name: str, port: int = None,
temp_dir: tempfile.TemporaryDirectory = None):
+ self._handler_path = handler_path
+ self._model_name = model_name
+ self._port = port
+ self._temp_dir = temp_dir
+ self._process = None
+ self._server_started = False
+ self._server_process_lock = threading.RLock()
+ self.start_server()
+
+ def start_server(self, retries=3):
+ with self._server_process_lock:
+ if not self._server_started:
+ if self._process:
+ logging.info("Terminating existing generic subprocess model server
before restart")
+ try:
+ self._process.terminate()
+ self._process.wait(timeout=5)
+ except Exception:
+ try:
+ self._process.kill()
+ except Exception:
+ pass
+ self._process = None
+ self._port = None
+
+ from apache_beam.utils import subprocess_server
+ if self._port is None:
+ self._port, = subprocess_server.pick_port(None)
+
+ cmd = [
+ sys.executable,
+ '-m',
+ 'apache_beam.ml.inference.subprocess_server',
+ '--handler_path',
+ self._handler_path,
+ '--port',
+ str(self._port),
+ ]
+ logging.info("Starting generic model server with %s", cmd)
+ self._process = subprocess.Popen(
+ cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+ # Emit the output of this command as info level logging.
+ def log_stdout():
+ line = self._process.stdout.readline()
+ while line:
+ logging.info(line.decode(errors='backslashreplace').rstrip())
+ line = self._process.stdout.readline()
+
+ t = threading.Thread(target=log_stdout)
+ t.daemon = True
+ t.start()
+
+ self.check_connectivity(retries)
+
+ def get_server_port(self) -> int:
+ if not self._server_started:
+ self.start_server()
+ return self._port
+
+ def check_connectivity(self, retries=3):
+ import urllib.request
+ import urllib.error
+
+ url = f"http://localhost:{self._port}/v1/models"
+ attempts = 0
+ max_attempts = 12 # 12 * 5s = 60s timeout
+ while self._process.poll() is None and attempts < max_attempts:
+ try:
+ # Use standard library to check connectivity to avoid extra
dependencies
+ req = urllib.request.Request(url, method="GET")
+ with urllib.request.urlopen(req, timeout=5) as response:
+ if response.status == 200:
+ self._server_started = True
+ return
+ except urllib.error.URLError:
+ pass
+ except Exception as e:
+ logging.warning("Error checking connectivity: %s", e)
+ attempts += 1
+ time.sleep(5)
+
+ if retries == 0:
+ self._server_started = False
+ raise Exception(
+ "Failed to start generic subprocess server, polling process exited
with code " +
+ f"{self._process.poll()}. Next time a request is tried, the server
will be restarted"
+ )
+ else:
+ self.start_server(retries - 1)
Review Comment:

There are two critical issues in `check_connectivity` and `start_server`:
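1. **Race Condition / Redundant Restarts**: If multiple threads concurrently
call `check_connectivity` due to an inference failure, each of them can fall
through to `start_server`. `start_server` is protected by
`self._server_process_lock`, but the polling loop and the decision to restart
in `check_connectivity` run outside the lock.

The race appears once `self._server_started` is reset to `False` before
restarting (the fix for item 2 below):
- A thread that already saw the old process exit can reset the flag after the
first thread has brought up a healthy server.
- When that thread then acquires the lock, it restarts the server again,
terminating the newly started healthy server.

`self._server_process_lock` is an `RLock`, so the nested `start_server` ->
`check_connectivity` calls can re-acquire it. We should therefore wrap the
entire `check_connectivity` method with this lock, so that subsequent threads
immediately see the healthy server and return without redundant restarts.

The trade-off is that other callers block on the lock for the whole polling
window: up to 12 attempts, each with a 5 s request timeout and a 5 s sleep.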
2. **No-op Restart Loop**: If the server crashes, `self._server_started`
remains `True`. When `check_connectivity` is called, the polling loop is
skipped because `self._process.poll()` is not `None`. It then calls
`start_server(retries - 1)`. However, because `self._server_started` is still
`True`, `start_server` skips the spawning block entirely and just calls
`check_connectivity` again. This results in a loop of no-op retries that never
actually restarts the process. We must reset `self._server_started = False`
before calling `start_server` to ensure the server is actually restarted.
```python
def check_connectivity(self, retries=3):
with self._server_process_lock:
import urllib.request
import urllib.error
url = f"http://localhost:{self._port}/v1/models"
attempts = 0
max_attempts = 12 # 12 * 5s = 60s timeout
while self._process.poll() is None and attempts < max_attempts:
try:
# Use standard library to check connectivity to avoid extra
dependencies
req = urllib.request.Request(url, method="GET")
with urllib.request.urlopen(req, timeout=5) as response:
if response.status == 200:
self._server_started = True
return
except urllib.error.URLError:
pass
except Exception as e:
logging.warning("Error checking connectivity: %s", e)
attempts += 1
time.sleep(5)
if retries == 0:
self._server_started = False
raise Exception(
"Failed to start generic subprocess server, polling process
exited with code " +
f"{self._process.poll()}. Next time a request is tried, the
server will be restarted"
)
else:
self._server_started = False
self.start_server(retries - 1)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]