tvalentyn commented on code in PR #27857:
URL: https://github.com/apache/beam/pull/27857#discussion_r1291719157


##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -243,70 +244,273 @@ def share_model_across_processes(self) -> bool:
     return False
 
 
+class _ModelManager:
+  """
+  A class for efficiently managing copies of multiple models. Will load a
+  single copy of each model into a multi_process_shared object and then
+  return a lookup key for that object. Optionally takes in a max_models
+  parameter, if that is set it will only hold that many models in memory at
+  once before evicting one (using LRU logic).
+  """
+  def __init__(
+      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+    """
+    Args:
+      mh_map: A map from keys to model handlers which can be used to load a
+        model.
+      max_models: The maximum number of models to load at any given time
+        before evicting 1 from memory (using LRU logic). Leave as None to
+        allow unlimited models.
+    """
+    self._max_models = max_models
+    self._mh_map: Dict[str, ModelHandler] = mh_map
+    # Maps a model's tag to the (shared handle, acquired model) pair that this
+    # manager holds for as long as the model stays loaded.
+    self._proxy_map: Dict[str, Tuple[multi_process_shared.MultiProcessShared,
+                                     Any]] = {}
+    # Maps a model key to its current tag, ordered from least to most recently
+    # used so that the first entry is the next eviction candidate.
+    self._tag_map: Dict[str, str] = OrderedDict()
+
+  def load(self, key: str) -> str:
+    """
+    Loads the appropriate model for the given key into memory.
+
+    If the model for this key is already loaded, it is only marked as the
+    most recently used and its existing tag is returned; it is not acquired
+    again.
+
+    Args:
+      key: the key associated with the model we'd like to load.
+    Returns:
+      the tag we can use to access the model using multi_process_shared.py.
+    """
+    if key in self._tag_map:
+      # Already loaded: just refresh its LRU position. Acquiring again here
+      # would overwrite the reference held in _proxy_map without releasing
+      # it, so the model could never be freed after eviction.
+      self._tag_map.move_to_end(key)
+      return self._tag_map[key]
+
+    # Look up the handler first so an unknown key does not leave a stale tag.
+    mh = self._mh_map[key]
+
+    # Map the key for a model to a unique tag that will persist until the model
+    # is released. This needs to be unique between releasing/reacquiring the
+    # model because otherwise the ProxyManager will try to reuse the model that
+    # has been released and deleted.
+    tag = uuid.uuid4().hex
+    self._tag_map[key] = tag
+
+    if self._max_models is not None and self._max_models < len(self._tag_map):
+      # If we're about to exceed our LRU size, release the least recently used
+      # model and drop our handle to it.
+      tag_to_remove = self._tag_map.popitem(last=False)[1]
+      shared_handle, model_to_remove = self._proxy_map.pop(tag_to_remove)
+      shared_handle.release(model_to_remove)
+
+    # Load the new model and hold a reference to it until it is evicted.
+    shared_handle = multi_process_shared.MultiProcessShared(
+        mh.load_model, tag=tag)
+    model_reference = shared_handle.acquire()
+    self._proxy_map[tag] = (shared_handle, model_reference)
+
+    return tag
+
+  def increment_max_models(self, increment: int):
+    """
+    Increments the number of models that this instance of a _ModelManager is
+    able to hold.
+    Args:
+      increment: the amount by which we are incrementing the number of models.
+    Raises:
+      ValueError: if this manager was created with max_models=None.
+    """
+    if self._max_models is None:
+      raise ValueError(
+          "Cannot increment max_models if self._max_models is None (unlimited"
+          " models mode).")
+    self._max_models += increment
+
+
+# Use a plain generic class instead of a NamedTuple because NamedTuples and
+# generics don't mix well across the board for all versions:
+# https://github.com/python/typing/issues/653
+class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
+  """
+  Associates a list of keys with the unkeyed ModelHandler that should run
+  inference on examples carrying those keys. A list of these is passed to
+  KeyedModelHandler to route different keys to different models.
+  """
+  def __init__(
+      self, keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]):
+    """
+    Args:
+      keys: The keys whose examples should be handled by ``mh``.
+      mh: The unkeyed ModelHandler to use for those keys.
+    """
+    self.keys = keys
+    self.mh = mh
+
+
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
                                      Tuple[KeyT, PredictionT],
-                                     ModelT]):
-  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+                                     Union[ModelT, _ModelManager]]):
+  def __init__(
+      self,
+      unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
+                     List[KeyMhMapping[KeyT, ExampleT, PredictionT, ModelT]]]):
     """A ModelHandler that takes keyed examples and returns keyed predictions.

     For example, if the original model is used with RunInference to take a
     PCollection[E] to a PCollection[P], this ModelHandler would take a
     PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible
-    to use the key to associate the outputs with the inputs.
+    to use the key to associate the outputs with the inputs. KeyedModelHandler
+    is able to accept either a single unkeyed ModelHandler or many different
+    model handlers corresponding to the keys for which that ModelHandler should
+    be used. For example, the following configuration could be used to map keys
+    k1-k3 to mh1 and keys k4-k5 to mh2:
+
+        k1 = ['k1', 'k2', 'k3']
+        k2 = ['k4', 'k5']
+        KeyedModelHandler([KeyMhMapping(k1, mh1), KeyMhMapping(k2, mh2)])
+
+    Note that a single copy of each of these models may all be held in memory
+    at the same time; be careful not to load too many large models or your
+    pipeline may cause Out of Memory exceptions.

     Args:
-      unkeyed: An implementation of ModelHandler that does not require keys.
+      unkeyed: Either (a) an implementation of ModelHandler that does not
+        require keys or (b) a list of KeyMhMappings mapping lists of keys to
+        unkeyed ModelHandlers.
+
+    Raises:
+      ValueError: if any wrapped ModelHandler defines pre/postprocessing
+        functions, if a KeyMhMapping has no keys, or if a key is mapped to
+        more than one ModelHandler.
     """
-    if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
-      raise Exception(
-          'Cannot make make an unkeyed model handler with pre or '
-          'postprocessing functions defined into a keyed model handler. All '
-          'pre/postprocessing functions must be defined on the outer model'
-          'handler.')
-    self._unkeyed = unkeyed
-    self._env_vars = unkeyed._env_vars
-
-  def load_model(self) -> ModelT:
-    return self._unkeyed.load_model()
+    self._many_models = isinstance(unkeyed, list)
+    if not self._many_models:
+      if len(unkeyed.get_preprocess_fns()) or len(
+          unkeyed.get_postprocess_fns()):
+        # ValueError (a subclass of Exception) matches the multi-model path.
+        raise ValueError(
+            'Cannot make an unkeyed model handler with pre or '
+            'postprocessing functions defined into a keyed model handler. All '
+            'pre/postprocessing functions must be defined on the outer model '
+            'handler.')
+      self._env_vars = getattr(unkeyed, '_env_vars', {})
+      self._unkeyed = unkeyed
+      return
+
+    # Per-handler env vars are ignored with multiple handlers (see the warning
+    # below), so the keyed handler itself defines none.
+    self._env_vars = {}
+    # Maps a model id (the first key of its KeyMhMapping) to its handler.
+    self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT,
+                                               ModelT]] = {}
+    # Maps every key to the id of the model handler that serves it.
+    self._key_to_id_map: Dict[str, str] = {}
+    for mh_tuple in unkeyed:
+      mh = mh_tuple.mh
+      keys = mh_tuple.keys
+      if len(mh.get_preprocess_fns()) or len(mh.get_postprocess_fns()):
+        raise ValueError(
+            'Cannot use an unkeyed model handler with pre or '
+            'postprocessing functions defined in a keyed model handler. All '
+            'pre/postprocessing functions must be defined on the outer model '
+            'handler.')
+      hints = mh.get_resource_hints()
+      if hints:
+        logging.warning(
+            'mh %s defines the following resource hints %s which will '
+            'be ignored. Resource hints are not respected when more than one '
+            'model handler is used in a KeyedModelHandler. If you would like '
+            'to specify resource hints, you can do so by overriding this '
+            'KeyedModelHandler and defining the get_resource_hints function.',
+            mh,
+            hints)
+      batch_kwargs = mh.batch_elements_kwargs()
+      if batch_kwargs:
+        logging.warning(
+            'mh %s defines the following batching kwargs %s '
+            'which will be ignored. Batching kwargs are not respected when '
+            'more than one model handler is used in a KeyedModelHandler. If '
+            'you would like to specify batching kwargs, you can do so by '
+            'overriding this KeyedModelHandler and defining the '
+            'batch_elements_kwargs function.',
+            mh,
+            batch_kwargs)
+      env_vars = getattr(mh, '_env_vars', {})
+      if env_vars:
+        logging.warning(
+            'mh %s defines the following _env_vars %s '
+            'which will be ignored. _env_vars are not respected when '
+            'more than one model handler is used in a KeyedModelHandler. '
+            'If you need env vars set at inference time, you can do so with '
+            'a custom inference function.',
+            mh,
+            env_vars)
+
+      if not keys:
+        raise ValueError(
+            f'Empty list maps to model handler {mh}. All model handlers must '
+            'have one or more associated keys.')
+      for key in keys:
+        if key in self._key_to_id_map:
+          raise ValueError(
+              f'key {key} maps to multiple model handlers. All keys must map '
+              'to exactly one model handler.')
+        self._key_to_id_map[key] = keys[0]
+      # Register the handler only once all of its keys have been validated.
+      self._id_to_mh_map[keys[0]] = mh
+
+  def load_model(self) -> Union[ModelT, _ModelManager]:
+    """Loads the model used by this handler.
+
+    Returns:
+      The unkeyed handler's model when a single ModelHandler was supplied;
+      otherwise a _ModelManager that loads per-key models on demand.
+    """
+    if self._many_models:
+      return _ModelManager(self._id_to_mh_map)
+    return self._unkeyed.load_model()
 
   def run_inference(
       self,
       batch: Sequence[Tuple[KeyT, ExampleT]],
-      model: ModelT,
+      model: Union[ModelT, _ModelManager],
       inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[Tuple[KeyT, PredictionT]]:
-    keys, unkeyed_batch = zip(*batch)
-    return zip(
-        keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+    """Runs inference on a batch of keyed examples.
+
+    With a single unkeyed ModelHandler, the whole batch is passed to it and
+    each prediction is zipped back with its key. With multiple per-key
+    ModelHandlers, ``model`` is the _ModelManager returned by load_model();
+    examples are grouped by key and each group is run on that key's model.
+    In that case predictions are returned grouped by key (in first-seen key
+    order) rather than in input order.
+    """
+    if not self._many_models:
+      keys, unkeyed_batch = zip(*batch)
+      return zip(
+          keys,
+          self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+
+    # Group examples by key so each handler runs once over all of its examples.
+    batch_by_key: Dict[KeyT, List[ExampleT]] = {}
+    for key, example in batch:
+      batch_by_key.setdefault(key, []).append(example)
+
+    predictions = []
+    for key, unkeyed_batch in batch_by_key.items():
+      model_id = self._key_to_id_map[key]
+      mh = self._id_to_mh_map[model_id]
+      keyed_model_tag = model.load(model_id)
+      keyed_model_shared_handle = multi_process_shared.MultiProcessShared(
+          mh.load_model, tag=keyed_model_tag)
+      keyed_model = keyed_model_shared_handle.acquire()
+      try:
+        predictions.extend(
+            (key, inference)
+            for inference in mh.run_inference(
+                unkeyed_batch, keyed_model, inference_args))
+      finally:
+        # Always release this call's reference, even if inference raises, so
+        # no outstanding reference to the shared model is left behind.
+        keyed_model_shared_handle.release(keyed_model)
+
+    return predictions
 
   def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
-    keys, unkeyed_batch = zip(*batch)
-    return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    """Returns the size in bytes of a batch of keyed examples."""
+    if self._many_models:
+      # There is no single unkeyed handler to delegate to, so measure the
+      # pickled keyed batch as a whole.
+      return len(pickle.dumps(batch))
+    keys, unkeyed_batch = zip(*batch)
+    key_bytes = len(pickle.dumps(keys))
+    return key_bytes + self._unkeyed.get_num_bytes(unkeyed_batch)
 
   def get_metrics_namespace(self) -> str:
-    return self._unkeyed.get_metrics_namespace()
+    """Returns the unkeyed handler's metrics namespace, or a fixed namespace
+    when multiple per-key handlers are used."""
+    if self._many_models:
+      return 'BeamML_KeyedModels'
+    return self._unkeyed.get_metrics_namespace()
 
   def get_resource_hints(self):
-    return self._unkeyed.get_resource_hints()
+    # With multiple per-key handlers no resource hints are reported.
+    return {} if self._many_models else self._unkeyed.get_resource_hints()
 
   def batch_elements_kwargs(self):
-    return self._unkeyed.batch_elements_kwargs()
+    # With multiple per-key handlers no batching kwargs are reported.
+    return {} if self._many_models else self._unkeyed.batch_elements_kwargs()
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
-    return self._unkeyed.validate_inference_args(inference_args)
+    """Validates inference_args against every wrapped ModelHandler."""
+    if self._many_models:
+      # The same inference_args are passed to every per-key handler during
+      # run_inference, so each of them must accept them.
+      for handler in self._id_to_mh_map.values():
+        handler.validate_inference_args(inference_args)
+      return None
+    return self._unkeyed.validate_inference_args(inference_args)
 
   def update_model_path(self, model_path: Optional[str] = None):
-    return self._unkeyed.update_model_path(model_path=model_path)
+    """Forwards a model path update to the unkeyed ModelHandler.
+
+    Raises:
+      RuntimeError: if multiple per-key ModelHandlers are used and a
+        model_path is given, since per-key updates are not supported.
+    """
+    if not self._many_models:
+      return self._unkeyed.update_model_path(model_path=model_path)
+    if model_path is not None:
+      raise RuntimeError(
+          'Model updates are currently not supported for '
+          'KeyedModelHandlers with multiple different per-key '
+          'ModelHandlers.')
-
-  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
-    return self._unkeyed.get_preprocess_fns()

Review Comment:
   Note that this is also the default implementation in the superclass, so this override can be removed.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to