tvalentyn commented on code in PR #27857:
URL: https://github.com/apache/beam/pull/27857#discussion_r1291718410


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -243,70 +244,273 @@ def share_model_across_processes(self) -> bool:
     return False
 
 
+class _ModelManager:
+  """
+  A class for efficiently managing copies of multiple models. Will load a
+  single copy of each model into a multi_process_shared object and then
+  return a lookup key for that object. Optionally takes in a max_models
+  parameter, if that is set it will only hold that many models in memory at
+  once before evicting one (using LRU logic).
+  """
+  def __init__(
+      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+    """
+    Args:
+      mh_map: A map from keys to model handlers which can be used to load a
+        model.
+      max_models: The maximum number of models to load at any given time
+        before evicting 1 from memory (using LRU logic). Leave as None to
+        allow unlimited models.
+    """
+    self._max_models = max_models
+    self._mh_map: Dict[str, ModelHandler] = mh_map
+    self._proxy_map: Dict[str, str] = {}
+    self._tag_map: Dict[
+        str, multi_process_shared.MultiProcessShared] = OrderedDict()
+
+  def load(self, key: str) -> str:
+    """
+    Loads the appropriate model for the given key into memory.
+    Args:
+      key: the key associated with the model we'd like to load.
+    Returns:
+      the tag we can use to access the model using multi_process_shared.py.
+    """
+    # Map the key for a model to a unique tag that will persist until the model
+    # is released. This needs to be unique between releasing/reacquiring th
+    # model because otherwise the ProxyManager will try to reuse the model that
+    # has been released and deleted.
+    if key in self._tag_map:
+      self._tag_map.move_to_end(key)
+    else:
+      self._tag_map[key] = uuid.uuid4().hex
+
+    tag = self._tag_map[key]
+    mh = self._mh_map[key]
+
+    if self._max_models is not None and self._max_models < len(self._tag_map):
+      # If we're about to exceed our LRU size, release the last used model.
+      tag_to_remove = self._tag_map.popitem(last=False)[1]
+      shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
+      shared_handle.release(model_to_remove)
+
+    # Load the new model
+    shared_handle = multi_process_shared.MultiProcessShared(
+        mh.load_model, tag=tag)
+    model_reference = shared_handle.acquire()
+    self._proxy_map[tag] = (shared_handle, model_reference)
+
+    return tag
+
+  def increment_max_models(self, increment: int):
+    """
+    Increments the number of models that this instance of a _ModelManager is
+    able to hold.
+    Args:
+      increment: the amount by which we are incrementing the number of models.
+    """
+    if self._max_models is None:
+      raise ValueError(
+          "Cannot increment max_models if self._max_models is None (unlimited" 
+
+          " models mode).")
+    self._max_models += increment
+
+
+# Use a dataclass instead of named tuple because NamedTuples and generics don't
+# mix well across the board for all versions:
+# https://github.com/python/typing/issues/653
+class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
+  def __init__(
+      self, keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self.keys = keys
+    self.mh = mh
+
+
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
                                      Tuple[KeyT, PredictionT],
-                                     ModelT]):
-  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+                                     Union[ModelT, _ModelManager]]):
+  def __init__(
+      self,
+      unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
+                     List[KeyMhMapping[KeyT, ExampleT, PredictionT, ModelT]]]):
     """A ModelHandler that takes keyed examples and returns keyed predictions.
 
     For example, if the original model is used with RunInference to take a
     PCollection[E] to a PCollection[P], this ModelHandler would take a
     PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible
-    to use the key to associate the outputs with the inputs.
+    to use the key to associate the outputs with the inputs. KeyedModelHandler
+    is able to accept either a single unkeyed ModelHandler or many different
+    model handlers corresponding to the keys for which that ModelHandler should
+    be used. For example, the following configuration could be used to map keys
+    1-3 to ModelHandler1 and keys 4-5 to ModelHandler2:
+
+        k1 = ['k1', 'k2', 'k3']
+        k2 = ['k4', 'k5']
+        KeyedModelHandler([KeyMhMapping(k1, mh1), KeyMhMapping(k2, mh2)])
+
+    Note that a single copy of each of these models may all be held in memory
+    at the same time; be careful not to load too many large models or your

Review Comment:
   Consider adding a TODO as a placeholder for the reference, in case someone
reads the docstring before the external reference is added.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to