tvalentyn commented on code in PR #27857:
URL: https://github.com/apache/beam/pull/27857#discussion_r1291713448


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -243,70 +244,273 @@ def share_model_across_processes(self) -> bool:
     return False
 
 
+class _ModelManager:
+  """
+  A class for efficiently managing copies of multiple models. Loads a single
+  copy of each model into a multi_process_shared object and returns a lookup
+  tag for that object. Optionally takes a max_models parameter; if set, at
+  most that many models are held in memory at once, with the least recently
+  used model evicted when the limit is exceeded.
+  """
+  def __init__(
+      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+    """
+    Args:
+      mh_map: A map from keys to model handlers which can be used to load a
+        model.
+      max_models: The maximum number of models to hold in memory at any given
+        time before evicting the least recently used one. Leave as None to
+        allow unlimited models.
+    """
+    self._max_models = max_models
+    self._mh_map: Dict[str, ModelHandler] = mh_map
+    self._proxy_map: Dict[str, Tuple[
+        multi_process_shared.MultiProcessShared, Any]] = {}
+    self._tag_map: Dict[str, str] = OrderedDict()
+
+  def load(self, key: str) -> str:
+    """
+    Loads the appropriate model for the given key into memory.
+    Args:
+      key: the key associated with the model we'd like to load.
+    Returns:
+      the tag we can use to access the model using multi_process_shared.py.
+    """
+    # Map the key for a model to a unique tag that will persist until the model
+    # is released. This needs to be unique between releasing/reacquiring the
+    # model because otherwise the ProxyManager will try to reuse the model that
+    # has been released and deleted.
+    if key in self._tag_map:
+      self._tag_map.move_to_end(key)
+    else:
+      self._tag_map[key] = uuid.uuid4().hex
+
+    tag = self._tag_map[key]
+    mh = self._mh_map[key]
+
+    if self._max_models is not None and self._max_models < len(self._tag_map):
+      # If we're about to exceed our LRU size, release the least recently
+      # used model.
+      tag_to_remove = self._tag_map.popitem(last=False)[1]
+      shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
+      shared_handle.release(model_to_remove)
+
+    # Load the new model
+    shared_handle = multi_process_shared.MultiProcessShared(
+        mh.load_model, tag=tag)
+    model_reference = shared_handle.acquire()
+    self._proxy_map[tag] = (shared_handle, model_reference)
+
+    return tag
+
+  def increment_max_models(self, increment: int):
+    """
+    Increments the number of models that this instance of a _ModelManager is
+    able to hold.
+    Args:
+      increment: the amount by which we are incrementing the number of models.
+    """
+    if self._max_models is None:
+      raise ValueError(
+          "Cannot increment max_models if self._max_models is None (unlimited" 
+
+          " models mode).")
+    self._max_models += increment
+
+
+# Use a dataclass instead of a NamedTuple because NamedTuples and generics
+# don't mix well across all Python versions:
+# https://github.com/python/typing/issues/653
+class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):

Review Comment:
   I agree it's less precise; but ultimately users care about the key->model
mapping, and the MH is just an intermediary. I was trying to come up with a
more readable and speakable name, especially since it's a public API. It's
also more common in Python to capitalize every letter of an acronym (e.g.
HTTP vs Http), but I'd rather have MhM than MHM in this case, I think. I
don't have a better name though. Maybe others have ideas.
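
   As a side note for context, here is a minimal, hypothetical sketch (not
part of the PR) of how the key -> model-handler mapping could feed the
_ModelManager from this diff. The sklearn handlers and model URIs below are
stand-ins made up for illustration:

```python
from apache_beam.ml.inference.base import _ModelManager
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

# The key -> ModelHandler mapping users ultimately care about; this is what
# KeyMhMapping (whatever it ends up being named) expresses in the public API.
mh_map = {
    'model_a': SklearnModelHandlerNumpy(model_uri='gs://some-bucket/a.pkl'),
    'model_b': SklearnModelHandlerNumpy(model_uri='gs://some-bucket/b.pkl'),
}

# Hold at most one model in memory; loading a second key evicts the least
# recently used model by releasing its multi_process_shared handle.
manager = _ModelManager(mh_map, max_models=1)

tag_a = manager.load('model_a')  # loads model_a into a shared process space
tag_b = manager.load('model_b')  # evicts model_a, then loads model_b

# A worker can then acquire the shared model through its tag, e.g.:
#   multi_process_shared.MultiProcessShared(
#       mh_map['model_b'].load_model, tag=tag_b).acquire()
```

The tag indirection is what lets the manager release a model and later reload
it without the ProxyManager handing back the stale, already-deleted object,
per the comment at the top of load().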



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
