This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch release-2.56.0
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/release-2.56.0 by this push:
     new fe00df6b924 Add ability to load multiple copies of a model across processes (#31052) (#31104)
fe00df6b924 is described below

commit fe00df6b92433e7ab8226ecdc03bdf4164b46e22
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Thu Apr 25 20:26:51 2024 -0400

    Add ability to load multiple copies of a model across processes (#31052) (#31104)
    
    * Add ability to load multiple copies of a model across processes
    
    * Push changes I had locally but not remotely
    
    * Lint
    
    * naming + lint
    
    * Changes from feedback
---
 sdks/python/apache_beam/ml/inference/base.py       | 110 +++++++++++++++++++--
 sdks/python/apache_beam/ml/inference/base_test.py  |  71 +++++++++++++
 .../ml/inference/huggingface_inference.py          |  36 +++++--
 .../apache_beam/ml/inference/onnx_inference.py     |  12 ++-
 .../apache_beam/ml/inference/pytorch_inference.py  |  24 ++++-
 .../apache_beam/ml/inference/sklearn_inference.py  |  24 ++++-
 .../ml/inference/tensorflow_inference.py           |  24 ++++-
 .../apache_beam/ml/inference/tensorrt_inference.py |  12 ++-
 8 files changed, 283 insertions(+), 30 deletions(-)
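
For orientation before the diff: a brief usage sketch, not part of this commit, of the model_copies option that the handlers below gain. The handler and transform names are the real ones touched by this change; the model path and input values are hypothetical. Passing model_copies implies sharing the model across processes, so a worker loads four copies in total rather than one per process.

import numpy
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

# Hypothetical model path; model_copies=4 also turns on cross-process sharing.
model_handler = SklearnModelHandlerNumpy(
    model_uri='gs://my-bucket/model.pkl', model_copies=4)

with beam.Pipeline() as p:
  _ = (
      p
      | beam.Create([numpy.array([1.0, 2.0]), numpy.array([3.0, 4.0])])
      | RunInference(model_handler))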

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 587b3060c23..6fe2d5acc5c 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -315,6 +315,13 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
     return False
 
+  def model_copies(self) -> int:
+    """Returns the maximum number of model copies that should be loaded at one
+    time. This only impacts model handlers that are using
+    share_model_across_processes to share their model across processes instead
+    of being loaded per process."""
+    return 1
+
   def override_metrics(self, metrics_namespace: str = '') -> bool:
     """Returns a boolean representing whether or not a model handler will
     override metrics reporting. If True, RunInference will not report any
@@ -795,6 +802,21 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
       return self._unkeyed.share_model_across_processes()
     return True
 
+  def model_copies(self) -> int:
+    if self._single_model:
+      return self._unkeyed.model_copies()
+    for mh in self._id_to_mh_map.values():
+      if mh.model_copies() != 1:
+        raise ValueError(
+            'KeyedModelHandler cannot map records to multiple '
+            'models if one or more of its ModelHandlers '
+            'require multiple model copies (set via '
+            'model_copies). To fix, verify that each '
+            'ModelHandler is not set to load multiple copies of '
+            'its model.')
+
+    return 1
+
   def override_metrics(self, metrics_namespace: str = '') -> bool:
     if self._single_model:
       return self._unkeyed.override_metrics(metrics_namespace)
@@ -902,6 +924,9 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def share_model_across_processes(self) -> bool:
     return self._unkeyed.share_model_across_processes()
 
+  def model_copies(self) -> int:
+    return self._unkeyed.model_copies()
+
 
 class _PrebatchedModelHandler(Generic[ExampleT, PredictionT, ModelT],
                               ModelHandler[Sequence[ExampleT],
@@ -952,6 +977,12 @@ class _PrebatchedModelHandler(Generic[ExampleT, PredictionT, ModelT],
   def should_skip_batching(self) -> bool:
     return True
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns()
 
@@ -1012,6 +1043,12 @@ class _PreProcessingModelHandler(Generic[ExampleT,
   def should_skip_batching(self) -> bool:
     return self._base.should_skip_batching()
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns()
 
@@ -1071,6 +1108,12 @@ class _PostProcessingModelHandler(Generic[ExampleT,
   def should_skip_batching(self) -> bool:
     return self._base.should_skip_batching()
 
+  def share_model_across_processes(self) -> bool:
+    return self._base.share_model_across_processes()
+
+  def model_copies(self) -> int:
+    return self._base.model_copies()
+
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._base.get_postprocess_fns() + [self._postprocess_fn]
 
@@ -1378,6 +1421,45 @@ class _MetricsCollector:
     self._inference_request_batch_byte_size.update(examples_byte_size)
 
 
+class _ModelRoutingStrategy():
+  """A class meant to sit in a shared location for mapping incoming batches to
+  different models. Currently only supports round-robin, but can be extended
+  to support other protocols if needed.
+  """
+  def __init__(self):
+    self._cur_index = 0
+
+  def next_model_index(self, num_models):
+    self._cur_index = (self._cur_index + 1) % num_models
+    return self._cur_index
+
+
+class _SharedModelWrapper():
+  """A router class to map incoming calls to the correct model.
+  
+    This allows us to round robin calls to models sitting in different
+    processes so that we can more efficiently use resources (e.g. GPUs).
+  """
+  def __init__(self, models: List[Any], model_tag: str):
+    self.models = models
+    if len(models) > 1:
+      self.model_router = multi_process_shared.MultiProcessShared(
+          lambda: _ModelRoutingStrategy(),
+          tag=f'{model_tag}_counter',
+          always_proxy=True).acquire()
+
+  def next_model(self):
+    if len(self.models) == 1:
+      # Short circuit if there's no routing strategy needed in order to
+      # avoid the cross-process call
+      return self.models[0]
+
+    return self.models[self.model_router.next_model_index(len(self.models))]
+
+  def all_models(self):
+    return self.models
+
+
 class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   def __init__(
       self,
@@ -1408,7 +1490,8 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   def _load_model(
       self,
       side_input_model_path: Optional[Union[str,
-                                            List[KeyModelPathMapping]]] = None):
+                                            List[KeyModelPathMapping]]] = None
+  ) -> _SharedModelWrapper:
     def load():
       """Function for constructing shared LoadedModel."""
       memory_before = _get_current_process_memory_in_bytes()
@@ -1416,8 +1499,10 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
       if isinstance(side_input_model_path, str):
         self._model_handler.update_model_path(side_input_model_path)
       else:
-        self._model_handler.update_model_paths(
-            self._model, side_input_model_path)
+        if self._model is not None:
+          models = self._model.all_models()
+          for m in models:
+            self._model_handler.update_model_paths(m, side_input_model_path)
       model = self._model_handler.load_model()
       end_time = _to_milliseconds(self._clock.time_ns())
       memory_after = _get_current_process_memory_in_bytes()
@@ -1434,10 +1519,15 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     if isinstance(side_input_model_path, str) and side_input_model_path != '':
       model_tag = side_input_model_path
     if self._model_handler.share_model_across_processes():
-      model = multi_process_shared.MultiProcessShared(
-          load, tag=model_tag, always_proxy=True).acquire()
+      models = []
+      for i in range(self._model_handler.model_copies()):
+        models.append(
+            multi_process_shared.MultiProcessShared(
+                load, tag=f'{model_tag}{i}', always_proxy=True).acquire())
+      model_wrapper = _SharedModelWrapper(models, model_tag)
     else:
       model = self._shared_model_handle.acquire(load, tag=model_tag)
+      model_wrapper = _SharedModelWrapper([model], model_tag)
     # since shared_model_handle is shared across threads, the model path
     # might not get updated in the model handler
     # because we directly get cached weak ref model from shared cache, instead
@@ -1445,8 +1535,11 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     if isinstance(side_input_model_path, str):
       self._model_handler.update_model_path(side_input_model_path)
     else:
-      self._model_handler.update_model_paths(self._model, side_input_model_path)
-    return model
+      if self._model is not None:
+        models = self._model.all_models()
+        for m in models:
+          self._model_handler.update_model_paths(m, side_input_model_path)
+    return model_wrapper
 
   def get_metrics_collector(self, prefix: str = ''):
     """
@@ -1476,8 +1569,9 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
   def _run_inference(self, batch, inference_args):
     start_time = _to_microseconds(self._clock.time_ns())
     try:
+      model = self._model.next_model()
       result_generator = self._model_handler.run_inference(
-          batch, self._model, inference_args)
+          batch, model, inference_args)
     except BaseException as e:
       if self._metrics_collector:
         self._metrics_collector.failed_batches_counter.inc()
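
A standalone sketch (local names only, no multi_process_shared proxying) of the round-robin routing added above: _ModelRoutingStrategy.next_model_index advances before returning, so with N copies the first batch goes to copy 1 and copy 0 is only reached on the Nth call.

class RoundRobinRouter:
  """Mirrors the advance-then-return behavior of _ModelRoutingStrategy."""
  def __init__(self):
    self._cur_index = 0

  def next_model_index(self, num_models):
    # Advance first, then return: with 3 models the order is 1, 2, 0, 1, 2, 0.
    self._cur_index = (self._cur_index + 1) % num_models
    return self._cur_index

models = ['copy-0', 'copy-1', 'copy-2']
router = RoundRobinRouter()
print([models[router.next_model_index(len(models))] for _ in range(6)])
# ['copy-1', 'copy-2', 'copy-0', 'copy-1', 'copy-2', 'copy-0']
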
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index d237aee1ce9..ec1664f494c 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -63,6 +63,15 @@ class FakeStatefulModel:
     self._state += amount
 
 
+class FakeIncrementingModel:
+  def __init__(self):
+    self._state = 0
+
+  def predict(self, example: int) -> int:
+    self._state += 1
+    return self._state
+
+
 class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def __init__(
       self,
@@ -71,6 +80,8 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
       max_batch_size=9999,
       multi_process_shared=False,
       state=None,
+      incrementing=False,
+      max_copies=1,
       num_bytes_per_element=None,
       **kwargs):
     self._fake_clock = clock
@@ -79,11 +90,16 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
     self._env_vars = kwargs.get('env_vars', {})
     self._multi_process_shared = multi_process_shared
     self._state = state
+    self._incrementing = incrementing
+    self._max_copies = max_copies
     self._num_bytes_per_element = num_bytes_per_element
 
   def load_model(self):
+    assert (not self._incrementing or self._state is None)
     if self._fake_clock:
       self._fake_clock.current_time_ns += 500_000_000  # 500ms
+    if self._incrementing:
+      return FakeIncrementingModel()
     if self._state is not None:
       return FakeStatefulModel(self._state)
     return FakeModel()
@@ -116,6 +132,9 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def share_model_across_processes(self):
     return self._multi_process_shared
 
+  def model_copies(self):
+    return self._max_copies
+
   def get_num_bytes(self, batch: Sequence[int]) -> int:
     if self._num_bytes_per_element:
       return self._num_bytes_per_element * len(batch)
@@ -258,6 +277,58 @@ class RunInferenceBaseTest(unittest.TestCase):
           FakeModelHandler(multi_process_shared=True))
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_simple_examples_multi_process_shared_multi_copy(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(multi_process_shared=True, max_copies=4))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_multi_process_shared_incrementing_multi_copy(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
+      expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(
+              multi_process_shared=True,
+              max_copies=4,
+              incrementing=True,
+              max_batch_size=1))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_mps_nobatch_incrementing_multi_copy(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
+      expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
+      batched_examples = [[example] for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(batched_examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(
+              multi_process_shared=True, max_copies=4,
+              incrementing=True).with_no_batching())
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_keyed_mps_incrementing_multi_copy(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10, 1, 5, 3, 10]
+      keyed_examples = [('abc', example) for example in examples]
+      expected = [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4]
+      keyed_expected = [('abc', val) for val in expected]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(
+              FakeModelHandler(
+                  multi_process_shared=True,
+                  max_copies=4,
+                  incrementing=True,
+                  max_batch_size=1)))
+      assert_that(actual, equal_to(keyed_expected), label='assert:inferences')
+
   def test_run_inference_impl_with_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
diff --git a/sdks/python/apache_beam/ml/inference/huggingface_inference.py b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
index 25367d22eaa..28e24d920fb 100644
--- a/sdks/python/apache_beam/ml/inference/huggingface_inference.py
+++ b/sdks/python/apache_beam/ml/inference/huggingface_inference.py
@@ -225,6 +225,7 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       **kwargs):
     """
     Implementation of the ModelHandler interface for HuggingFace with
@@ -257,6 +258,9 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -276,7 +280,8 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
       self._batching_kwargs["max_batch_size"] = max_batch_size
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
     self._framework = framework
 
     _validate_constructor_args(
@@ -350,7 +355,10 @@ class HuggingFaceModelHandlerKeyedTensor(ModelHandler[Dict[str,
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
 
   def get_metrics_namespace(self) -> str:
     """
@@ -405,6 +413,7 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       **kwargs):
     """
     Implementation of the ModelHandler interface for HuggingFace with
@@ -437,6 +446,9 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -456,7 +468,8 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
       self._batching_kwargs["max_batch_size"] = max_batch_size
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
     self._framework = ""
 
     _validate_constructor_args(
@@ -537,7 +550,10 @@ class HuggingFaceModelHandlerTensor(ModelHandler[Union[tf.Tensor, torch.Tensor],
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
 
   def get_metrics_namespace(self) -> str:
     """
@@ -578,6 +594,7 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       **kwargs):
     """
     Implementation of the ModelHandler interface for Hugging Face Pipelines.
@@ -618,6 +635,9 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
 
@@ -637,7 +657,8 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
       self._batching_kwargs['max_batch_size'] = max_batch_size
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
 
     # Check if the device is specified twice. If true then the device parameter
     # of model handler is overridden.
@@ -718,7 +739,10 @@ class HuggingFacePipelineModelHandler(ModelHandler[str,
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
 
   def get_metrics_namespace(self) -> str:
     """
diff --git a/sdks/python/apache_beam/ml/inference/onnx_inference.py b/sdks/python/apache_beam/ml/inference/onnx_inference.py
index f7b6c0115af..e7af114ad43 100644
--- a/sdks/python/apache_beam/ml/inference/onnx_inference.py
+++ b/sdks/python/apache_beam/ml/inference/onnx_inference.py
@@ -64,6 +64,7 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
       *,
       inference_fn: NumpyInferenceFn = default_numpy_inference_fn,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       min_batch_size: Optional[int] = None,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
@@ -84,6 +85,9 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       min_batch_size: the minimum batch size to use when batching inputs.
       max_batch_size: the maximum batch size to use when batching inputs.
       max_batch_duration_secs: the maximum amount of time to buffer a batch
@@ -97,7 +101,8 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
     self._provider_options = provider_options
     self._model_inference_fn = inference_fn
     self._env_vars = kwargs.get('env_vars', {})
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
     self._batching_kwargs = {}
     if min_batch_size is not None:
       self._batching_kwargs["min_batch_size"] = min_batch_size
@@ -157,7 +162,10 @@ class OnnxModelHandlerNumpy(ModelHandler[numpy.ndarray,
     return 'BeamML_Onnx'
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
 
   def batch_elements_kwargs(self) -> Mapping[str, Any]:
     return self._batching_kwargs
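
The same two constructor lines recur in every handler touched by this commit: passing model_copies implicitly enables cross-process sharing, while large_model alone keeps a single shared copy. A minimal sketch of that logic follows; resolve_sharing is an illustrative helper, not a Beam API.

from typing import Optional, Tuple

def resolve_sharing(large_model: bool,
                    model_copies: Optional[int]) -> Tuple[bool, int]:
  # Mirrors: self._share_across_processes = large_model or (model_copies is not None)
  #          self._model_copies = model_copies or 1
  share_across_processes = large_model or (model_copies is not None)
  copies = model_copies or 1
  return share_across_processes, copies

print(resolve_sharing(False, None))  # (False, 1): one model per process
print(resolve_sharing(True, None))   # (True, 1): one copy shared across processes
print(resolve_sharing(False, 4))     # (True, 4): four shared copies
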
diff --git a/sdks/python/apache_beam/ml/inference/pytorch_inference.py b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
index 480dc538195..9a89cba7243 100644
--- a/sdks/python/apache_beam/ml/inference/pytorch_inference.py
+++ b/sdks/python/apache_beam/ml/inference/pytorch_inference.py
@@ -195,6 +195,7 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       load_model_args: Optional[Dict[str, Any]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for PyTorch.
@@ -234,6 +235,9 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       load_model_args: a dictionary of parameters passed to the torch.load
         function to specify custom config for loading models.
       kwargs: 'env_vars' can be used to set environment variables
@@ -262,7 +266,8 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
     self._torch_script_model_path = torch_script_model_path
     self._load_model_args = load_model_args if load_model_args else {}
     self._env_vars = kwargs.get('env_vars', {})
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
 
     _validate_constructor_args(
         state_dict_path=self._state_dict_path,
@@ -344,7 +349,10 @@ class PytorchModelHandlerTensor(ModelHandler[torch.Tensor,
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
 
 
 def default_keyed_tensor_inference_fn(
@@ -428,6 +436,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       load_model_args: Optional[Dict[str, Any]] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for PyTorch.
@@ -472,6 +481,9 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       load_model_args: a dictionary of parameters passed to the torch.load
         function to specify custom config for loading models.
       kwargs: 'env_vars' can be used to set environment variables
@@ -500,7 +512,8 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
     self._torch_script_model_path = torch_script_model_path
     self._load_model_args = load_model_args if load_model_args else {}
     self._env_vars = kwargs.get('env_vars', {})
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
 
     _validate_constructor_args(
         state_dict_path=self._state_dict_path,
@@ -584,4 +597,7 @@ class PytorchModelHandlerKeyedTensor(ModelHandler[Dict[str, torch.Tensor],
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
diff --git a/sdks/python/apache_beam/ml/inference/sklearn_inference.py b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
index befeca7f33b..a29657968ea 100644
--- a/sdks/python/apache_beam/ml/inference/sklearn_inference.py
+++ b/sdks/python/apache_beam/ml/inference/sklearn_inference.py
@@ -92,6 +92,7 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       **kwargs):
     """ Implementation of the ModelHandler interface for scikit-learn
     using numpy arrays as input.
@@ -118,6 +119,9 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
     """
@@ -132,7 +136,8 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._env_vars = kwargs.get('env_vars', {})
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
 
   def load_model(self) -> BaseEstimator:
     """Loads and initializes a model for processing."""
@@ -186,7 +191,10 @@ class SklearnModelHandlerNumpy(ModelHandler[numpy.ndarray,
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
 
 
 PandasInferenceFn = Callable[
@@ -219,6 +227,7 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for scikit-learn that
     supports pandas dataframes.
@@ -248,6 +257,9 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       kwargs: 'env_vars' can be used to set environment variables
         before loading the model.
     """
@@ -262,7 +274,8 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
     self._env_vars = kwargs.get('env_vars', {})
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
 
   def load_model(self) -> BaseEstimator:
     """Loads and initializes a model for processing."""
@@ -318,4 +331,7 @@ class SklearnModelHandlerPandas(ModelHandler[pandas.DataFrame,
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
diff --git a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
index 0802868a1dd..78b59975e63 100644
--- a/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorflow_inference.py
@@ -112,6 +112,7 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
 
@@ -137,6 +138,9 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
           memory pressure if you load multiple copies. Given a model that
           consumes N memory and a machine with W cores and M memory, you should
           set this to True if N*W > M.
+        model_copies: The exact number of models that you would like loaded
+          onto your machine. This can be useful if you know your CPU or GPU
+          capacity exactly and want to maximize resource utilization.
         kwargs: 'env_vars' can be used to set environment variables
           before loading the model.
 
@@ -157,7 +161,8 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
       self._batching_kwargs['max_batch_size'] = max_batch_size
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
 
   def load_model(self) -> tf.Module:
     """Loads and initializes a Tensorflow model for processing."""
@@ -222,7 +227,10 @@ class TFModelHandlerNumpy(ModelHandler[numpy.ndarray,
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
 
 
 class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
@@ -240,6 +248,7 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
       max_batch_size: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for Tensorflow.
 
@@ -270,6 +279,9 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
           memory pressure if you load multiple copies. Given a model that
           consumes N memory and a machine with W cores and M memory, you should
           set this to True if N*W > M.
+        model_copies: The exact number of models that you would like loaded
+          onto your machine. This can be useful if you know your CPU or GPU
+          capacity exactly and want to maximize resource utilization.
         kwargs: 'env_vars' can be used to set environment variables
           before loading the model.
 
@@ -290,7 +302,8 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
       self._batching_kwargs['max_batch_size'] = max_batch_size
     if max_batch_duration_secs is not None:
       self._batching_kwargs["max_batch_duration_secs"] = max_batch_duration_secs
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
 
   def load_model(self) -> tf.Module:
     """Loads and initializes a tensorflow model for processing."""
@@ -355,4 +368,7 @@ class TFModelHandlerTensor(ModelHandler[tf.Tensor, PredictionResult,
     return self._batching_kwargs
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies
diff --git a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
index 53b81c0c36c..b38947b494c 100644
--- a/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
+++ b/sdks/python/apache_beam/ml/inference/tensorrt_inference.py
@@ -230,6 +230,7 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
       *,
       inference_fn: TensorRTInferenceFn = _default_tensorRT_inference_fn,
       large_model: bool = False,
+      model_copies: Optional[int] = None,
       max_batch_duration_secs: Optional[int] = None,
       **kwargs):
     """Implementation of the ModelHandler interface for TensorRT.
@@ -254,6 +255,9 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
         memory pressure if you load multiple copies. Given a model that
         consumes N memory and a machine with W cores and M memory, you should
         set this to True if N*W > M.
+      model_copies: The exact number of models that you would like loaded
+        onto your machine. This can be useful if you know your CPU or GPU
+        capacity exactly and want to maximize resource utilization.
       max_batch_duration_secs: the maximum amount of time to buffer 
         a batch before emitting; used in streaming contexts.
       kwargs: Additional arguments like 'engine_path' and 'onnx_path' are
@@ -272,7 +276,8 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     elif 'onnx_path' in kwargs:
       self.onnx_path = kwargs.get('onnx_path')
     self._env_vars = kwargs.get('env_vars', {})
-    self._large_model = large_model
+    self._share_across_processes = large_model or (model_copies is not None)
+    self._model_copies = model_copies or 1
 
   def batch_elements_kwargs(self):
     """Sets min_batch_size and max_batch_size of a TensorRT engine."""
@@ -334,4 +339,7 @@ class TensorRTEngineHandlerNumPy(ModelHandler[np.ndarray,
     return 'BeamML_TensorRT'
 
   def share_model_across_processes(self) -> bool:
-    return self._large_model
+    return self._share_across_processes
+
+  def model_copies(self) -> int:
+    return self._model_copies

