tvalentyn commented on code in PR #31052:
URL: https://github.com/apache/beam/pull/31052#discussion_r1575378402
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -795,6 +802,21 @@ def share_model_across_processes(self) -> bool:
return self._unkeyed.share_model_across_processes()
return True
+ def max_shared_model_copies(self) -> int:
+ if self._single_model:
+ return self._unkeyed.max_shared_model_copies()
+ for mh in self._id_to_mh_map.values():
+ if mh.max_shared_model_copies() != 1:
+ raise ValueError(
+ 'KeyedModelHandler cannot map records to multiple '
+ 'models if one or more of its ModelHandlers '
+ 'require multiple model copies (set via'
Review Comment:
```suggestion
'require multiple model copies (set via '
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1434,19 +1482,28 @@ def load():
if isinstance(side_input_model_path, str) and side_input_model_path != '':
model_tag = side_input_model_path
if self._model_handler.share_model_across_processes():
- model = multi_process_shared.MultiProcessShared(
- load, tag=model_tag, always_proxy=True).acquire()
+ # TODO - update this to populate a list of models of configurable length
Review Comment:
Not sure I follow the TODO. Should we link a GH issue and include the necessary
details there?
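For example (sketch only; the issue number is a placeholder for one we'd file):
```python
# TODO(https://github.com/apache/beam/issues/<issue-number>): update this
# to populate a list of models of configurable length.
```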
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1434,19 +1518,28 @@ def load():
if isinstance(side_input_model_path, str) and side_input_model_path != '':
model_tag = side_input_model_path
if self._model_handler.share_model_across_processes():
- model = multi_process_shared.MultiProcessShared(
- load, tag=model_tag, always_proxy=True).acquire()
+ # TODO - update this to populate a list of models of configurable length
+ models = []
+ for i in range(self._model_handler.max_shared_model_copies()):
+ models.append(
+ multi_process_shared.MultiProcessShared(
+ load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
+ model_wrapper = _CrossProcessModelWrapper(models, model_tag)
else:
model = self._shared_model_handle.acquire(load, tag=model_tag)
+ model_wrapper = _CrossProcessModelWrapper([model], model_tag)
# since shared_model_handle is shared across threads, the model path
# might not get updated in the model handler
# because we directly get cached weak ref model from shared cache, instead
# of calling load(). For sanity check, call update_model_path again.
if isinstance(side_input_model_path, str):
self._model_handler.update_model_path(side_input_model_path)
else:
- self._model_handler.update_model_paths(self._model,
- side_input_model_path)
- return model
+ if self._model is not None:
+ models = self._model.all_models()
+ for m in models:
+ self._model_handler.update_model_paths(m, side_input_model_path)
+ return model_wrapper
Review Comment:
Can we have a typehint for what `_load_model` returns, or at least describe
the return type in a docstring?
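Something like this would make it explicit (sketch only; exact signature aside,
the point is the return annotation and docstring):
```python
def _load_model(self, side_input_model_path=None) -> '_CrossProcessModelWrapper':
  """Loads the configured model copies.

  Returns:
    A _CrossProcessModelWrapper that routes calls across the loaded model
    copies (wrapping a single copy in the non-shared case).
  """
```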
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1378,6 +1421,45 @@ def update(
self._inference_request_batch_byte_size.update(examples_byte_size)
+class _ModelRoutingStrategy():
+ """A class meant to sit in a shared location for mapping incoming batches to
+ different models. Currently only supports round-robin, but can be extended
+ to support other protocols if needed.
+ """
+ def __init__(self):
+ self._cur_index = 0
+
+ def next_model_index(self, num_models):
+ self._cur_index = (self._cur_index + 1) % num_models
+ return self._cur_index
+
+
+class _CrossProcessModelWrapper():
+ """A router class to map incoming calls to the correct model.
+
+ This allows us to round robin calls to models sitting in different
+ processes so that we can more efficiently use resources (e.g. GPUs).
+ """
+ def __init__(self, models: List[Any], model_tag: str):
+ self.models = models
+ if len(models) > 0:
Review Comment:
Should this be an `assert`?
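If an empty list is always a programming error here, something like the
following might read better (sketch):
```python
assert len(models) > 0, 'Expected at least one model copy.'
```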
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1416,8 +1498,10 @@ def load():
if isinstance(side_input_model_path, str):
self._model_handler.update_model_path(side_input_model_path)
else:
- self._model_handler.update_model_paths(
- self._model, side_input_model_path)
+ if self._model is not None:
Review Comment:
what is the typehint for self._model ? Can we specify it in the constructor
where we assign it? in which case can it be None?
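For example, in the constructor (sketch; assumes the wrapper type introduced
in this PR):
```python
self._model: Optional['_CrossProcessModelWrapper'] = None
```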
##########
sdks/python/apache_beam/ml/inference/base_test.py:
##########
@@ -79,11 +90,15 @@ def __init__(
self._env_vars = kwargs.get('env_vars', {})
self._multi_process_shared = multi_process_shared
self._state = state
+ self._incrementing = incrementing
+ self._max_copies = max_copies
self._num_bytes_per_element = num_bytes_per_element
def load_model(self):
if self._fake_clock:
self._fake_clock.current_time_ns += 500_000_000 # 500ms
+ if self._incrementing:
+ return FakeIncrementingModel()
Review Comment:
Let's add an assertion that the incrementing and stateful settings are not
used at the same time.
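For example (sketch; assuming the constructor keeps the `incrementing` and
`state` parameters shown above):
```python
if incrementing:
  assert state is None, (
      'Incrementing and stateful settings cannot be used together.')
```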
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -315,6 +315,13 @@ def share_model_across_processes(self) -> bool:
https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
return False
+ def max_shared_model_copies(self) -> int:
Review Comment:
Possible alternative names: `model_copies`, `model_copies_to_load` (I asked a
related question in the doc).
##########
sdks/python/apache_beam/ml/inference/huggingface_inference.py:
##########
@@ -257,6 +258,9 @@ def __init__(
memory pressure if you load multiple copies. Given a model that
consumes N memory and a machine with W cores and M memory, you should
set this to True if N*W > M.
+ model_copies: The exact number of models that you would like loaded
Review Comment:
TODO(self): revisit the wording around when this option becomes a no-op.
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -795,6 +802,21 @@ def share_model_across_processes(self) -> bool:
return self._unkeyed.share_model_across_processes()
return True
+ def max_shared_model_copies(self) -> int:
+ if self._single_model:
+ return self._unkeyed.max_shared_model_copies()
+ for mh in self._id_to_mh_map.values():
+ if mh.max_shared_model_copies() != 1:
+ raise ValueError(
+ 'KeyedModelHandler cannot map records to multiple '
+ 'models if one or more of its ModelHandlers '
+ 'require multiple model copies (set via'
+ 'max_shared_model_copies). To fix, verify that each '
Review Comment:
```suggestion
'model_copies). To fix, verify that each '
```
##########
sdks/python/apache_beam/ml/inference/base.py:
##########
@@ -1434,19 +1518,28 @@ def load():
if isinstance(side_input_model_path, str) and side_input_model_path != '':
model_tag = side_input_model_path
if self._model_handler.share_model_across_processes():
- model = multi_process_shared.MultiProcessShared(
- load, tag=model_tag, always_proxy=True).acquire()
+ # TODO - update this to populate a list of models of configurable length
+ models = []
+ for i in range(self._model_handler.max_shared_model_copies()):
+ models.append(
+ multi_process_shared.MultiProcessShared(
+ load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
+ model_wrapper = _CrossProcessModelWrapper(models, model_tag)
else:
model = self._shared_model_handle.acquire(load, tag=model_tag)
+ model_wrapper = _CrossProcessModelWrapper([model], model_tag)
Review Comment:
Why do we need `_CrossProcessModelWrapper` when `share_model_across_processes`
is not used? Should we use some single-process model wrapper stub instead?
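For example, a minimal stand-in could look like this (sketch; `all_models()`
matches the usage above, `next_model()` is an assumed name for the routing
accessor):
```python
class _SingleModelWrapper:
  """Trivial wrapper exposing the same interface for the non-shared path."""
  def __init__(self, model):
    self._model = model

  def next_model(self):
    # Only one model, so routing is a no-op.
    return self._model

  def all_models(self):
    return [self._model]
```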