This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/mpsRi
in repository https://gitbox.apache.org/repos/asf/beam.git

commit dbdc95a6552b242563e3d7c9a08d9f95449ca869
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Fri May 12 21:08:29 2023 -0400

    Allow model handlers to request multi_process_shared model
---
 sdks/python/apache_beam/ml/inference/base.py      |  36 ++++-
 sdks/python/apache_beam/ml/inference/base_test.py | 174 +++++++++++++++++++++-
 2 files changed, 205 insertions(+), 5 deletions(-)
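
For reviewers: the change below adds an opt-in hook, ModelHandler.share_model_across_processes(). When a handler returns True, RunInference acquires the model through apache_beam.utils.multi_process_shared (a single instance shared by every process on a worker) instead of apache_beam.utils.shared (one instance per process), and each RunInference transform now draws a random uuid tag so that several transforms sharing models in one pipeline do not collide on a single cache entry. The sketch below is illustrative only and is not part of the commit; _FakeLargeModel and MyLargeModelHandler are hypothetical, and only the share_model_across_processes() override depends on the new API.

from typing import Iterable, Optional, Sequence

import apache_beam as beam
from apache_beam.ml.inference import base


class _FakeLargeModel:
  """Hypothetical stand-in for a model too large to copy into every process."""
  def predict(self, example: int) -> int:
    return example + 1


class MyLargeModelHandler(base.ModelHandler[int, int, _FakeLargeModel]):
  def __init__(self):
    # Mirrors the fake handlers in base_test.py, which also set _env_vars.
    self._env_vars = {}

  def load_model(self) -> _FakeLargeModel:
    return _FakeLargeModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: _FakeLargeModel,
      inference_args: Optional[dict] = None) -> Iterable[int]:
    return [model.predict(example) for example in batch]

  def share_model_across_processes(self) -> bool:
    # Opt in: load one shared copy per worker instead of one copy per process.
    return True


with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([1, 2, 3])
      | base.RunInference(MyLargeModelHandler())
      | beam.Map(print))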

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 0a62c26887b..a60f7365b42 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -33,6 +33,7 @@ import pickle
 import sys
 import threading
 import time
+import uuid
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -47,6 +48,7 @@ from typing import TypeVar
 from typing import Union
 
 import apache_beam as beam
+from apache_beam.utils import multi_process_shared
 from apache_beam.utils import shared
 
 try:
@@ -227,6 +229,15 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     inference result in order from first applied to last applied."""
     return _PostProcessingModelHandler(self, fn)
 
+  def share_model_across_processes(self) -> bool:
+    """Returns a boolean representing whether or not a model should
+    be shared across multiple processes instead of being loaded per process.
+    This is primarily useful for large models that can't fit multiple copies in
+    memory. Multi-process support may vary by runner, but this will fall back to
+    loading per process as necessary. See
+    https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
+    return False
+
 
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
@@ -290,6 +301,9 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._unkeyed.get_postprocess_fns()
 
+  def share_model_across_processes(self) -> bool:
+    return self._unkeyed.share_model_across_processes()
+
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                              ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -379,6 +393,9 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._unkeyed.get_postprocess_fns()
 
+  def share_model_across_processes(self) -> bool:
+    return self._unkeyed.share_model_across_processes()
+
 
 class _PreProcessingModelHandler(Generic[ExampleT,
                                          PredictionT,
@@ -538,6 +555,9 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
     self._with_exception_handling = False
     self._watch_model_pattern = watch_model_pattern
     self._kwargs = kwargs
+    # Generate a random tag to use for shared.py and multi_process_shared.py to
+    # allow us to effectively disambiguate in multi-model settings.
+    self._model_tag = uuid.uuid4().hex
 
   def _get_model_metadata_pcoll(self, pipeline):
     # avoid circular imports.
@@ -623,7 +643,8 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
             self._model_handler,
             self._clock,
             self._metrics_namespace,
-            self._enable_side_input_loading),
+            self._enable_side_input_loading,
+            self._model_tag),
         self._inference_args,
         beam.pvalue.AsSingleton(
             self._model_metadata_pcoll,
@@ -780,7 +801,8 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock,
       metrics_namespace,
-      enable_side_input_loading: bool = False):
+      enable_side_input_loading: bool = False,
+      model_tag: str = "RunInference"):
     """A DoFn implementation generic to frameworks.
 
       Args:
@@ -789,6 +811,7 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
         metrics_namespace: Namespace of the transform to collect metrics.
         enable_side_input_loading: Bool to indicate if model updates
             with side inputs.
+        model_tag: Tag to use to disambiguate models in multi-model settings.
     """
     self._model_handler = model_handler
     self._shared_model_handle = shared.Shared()
@@ -797,6 +820,7 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     self._metrics_namespace = metrics_namespace
     self._enable_side_input_loading = enable_side_input_loading
     self._side_input_path = None
+    self._model_tag = model_tag
 
   def _load_model(self, side_input_model_path: Optional[str] = None):
     def load():
@@ -815,7 +839,13 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
 
     # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
     # model.
-    model = self._shared_model_handle.acquire(load, tag=side_input_model_path)
+    if self._model_handler.share_model_across_processes():
+      # TODO - make this a more robust tag than 'RunInference'
+      model = multi_process_shared.MultiProcessShared(
+          load, tag=side_input_model_path or self._model_tag).acquire()
+    else:
+      model = self._shared_model_handle.acquire(
+          load, tag=side_input_model_path or self._model_tag)
     # since shared_model_handle is shared across threads, the model path
     # might not get updated in the model handler
     # because we directly get cached weak ref model from shared cache, instead
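
For anyone unfamiliar with apache_beam.utils.multi_process_shared: the last hunk above picks between two caches, shared.Shared (per process, holding a weak reference to the model) and MultiProcessShared (one instance per tag, shared by all processes on the worker and handed out through a proxy object, which is what the str(type(model)) checks in the tests below rely on). A minimal standalone sketch follows; ExpensiveModel and the tag value are hypothetical, while the MultiProcessShared(...).acquire() call pattern is taken from the hunk above.

from apache_beam.utils import multi_process_shared


class ExpensiveModel:
  """Hypothetical object that is too costly to instantiate in every process."""
  def __init__(self):
    self.weights = list(range(1000))

  def predict(self, example: int) -> int:
    return example + 1


# The constructor runs at most once per tag on a worker; every process that
# acquires the same tag gets a handle to that single instance.
model = multi_process_shared.MultiProcessShared(
    ExpensiveModel, tag='my_model_tag').acquire()
print(model.predict(41))  # 42, served through the shared handle
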
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 455dfe208b1..afd336f9a47 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -50,11 +50,17 @@ class FakeModel:
 
 class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def __init__(
-      self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs):
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      multi_process_shared=False,
+      **kwargs):
     self._fake_clock = clock
     self._min_batch_size = min_batch_size
     self._max_batch_size = max_batch_size
     self._env_vars = kwargs.get('env_vars', {})
+    self._multi_process_shared = multi_process_shared
 
   def load_model(self):
     if self._fake_clock:
@@ -66,6 +72,12 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[int]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
     for example in batch:
@@ -80,13 +92,21 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
         'max_batch_size': self._max_batch_size
     }
 
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
 
 class FakeModelHandlerReturnsPredictionResult(
     base.ModelHandler[int, base.PredictionResult, FakeModel]):
-  def __init__(self, clock=None, model_id='fake_model_id_default'):
+  def __init__(
+      self,
+      clock=None,
+      model_id='fake_model_id_default',
+      multi_process_shared=False):
     self.model_id = model_id
     self._fake_clock = clock
     self._env_vars = {}
+    self._multi_process_shared = multi_process_shared
 
   def load_model(self):
     return FakeModel()
@@ -96,6 +116,12 @@ class FakeModelHandlerReturnsPredictionResult(
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[base.PredictionResult]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     for example in batch:
       yield base.PredictionResult(
           model_id=self.model_id,
@@ -105,6 +131,9 @@ class FakeModelHandlerReturnsPredictionResult(
   def update_model_path(self, model_path: Optional[str] = None):
     self.model_id = model_path if model_path else self.model_id
 
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
 
 class FakeClock:
   def __init__(self):
@@ -156,6 +185,15 @@ class RunInferenceBaseTest(unittest.TestCase):
       actual = pcoll | base.RunInference(FakeModelHandler())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_simple_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(multi_process_shared=True))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
   def test_run_inference_impl_with_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
@@ -183,6 +221,35 @@ class RunInferenceBaseTest(unittest.TestCase):
           model_handler)
       assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
 
+  def test_run_inference_impl_with_keyed_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [(i, example + 1) for i, example in enumerate(examples)]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler(multi_process_shared=True)))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(
+          FakeModelHandler(multi_process_shared=True))
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_preprocessing(self):
     def mult_two(example: str) -> int:
       return int(example) * 2
@@ -634,6 +701,31 @@ class RunInferenceBaseTest(unittest.TestCase):
         'singleton view. First two elements encountered are' in str(
             e.exception))
 
+  def test_run_inference_with_iterable_side_input_multi_process_shared(self):
+    test_pipeline = TestPipeline()
+    side_input = (
+        test_pipeline | "CreateDummySideInput" >> beam.Create(
+            [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
+        | "ApplySideInputWindow" >> beam.WindowInto(
+            window.GlobalWindows(),
+            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING))
+
+    test_pipeline.options.view_as(StandardOptions).streaming = True
+    with self.assertRaises(ValueError) as e:
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(
+              FakeModelHandler(multi_process_shared=True),
+              model_metadata_pcoll=side_input))
+      test_pipeline.run()
+
+    self.assertTrue(
+        'PCollection of size 2 with more than one element accessed as a '
+        'singleton view. First two elements encountered are' in str(
+            e.exception))
+
   def test_run_inference_empty_side_input(self):
     model_handler = FakeModelHandlerReturnsPredictionResult()
     main_input_elements = [1, 2]
@@ -727,6 +819,84 @@ class RunInferenceBaseTest(unittest.TestCase):
 
       assert_that(result_pcoll, equal_to(expected_result))
 
+  def test_run_inference_side_input_in_batch_multi_process_shared(self):
+    first_ts = math.floor(time.time()) - 30
+    interval = 7
+
+    sample_main_input_elements = ([
+        first_ts - 2,
+        first_ts + 1,
+        first_ts + 8,
+        first_ts + 15,
+        first_ts + 22,
+    ])
+
+    sample_side_input_elements = [
+        (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
+        # if model_id is empty string, we use the default model
+        # handler model URI.
+        (
+            first_ts + 8,
+            base.ModelMetadata(
+                model_id='fake_model_id_1', model_name='fake_model_id_1')),
+        (
+            first_ts + 15,
+            base.ModelMetadata(
+                model_id='fake_model_id_2', model_name='fake_model_id_2'))
+    ]
+
+    model_handler = FakeModelHandlerReturnsPredictionResult(
+        multi_process_shared=True)
+
+    # applying GroupByKey to utilize windowing according to
+    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
+    class _EmitElement(beam.DoFn):
+      def process(self, element):
+        for e in element:
+          yield e
+
+    with TestPipeline() as pipeline:
+      side_input = (
+          pipeline
+          |
+          "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
+          | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
+          | beam.WindowInto(
+              window.FixedWindows(interval),
+              accumulation_mode=trigger.AccumulationMode.DISCARDING)
+          | beam.Map(lambda x: ('key', x))
+          | beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | "EmitSideInput" >> beam.ParDo(_EmitElement()))
+
+      result_pcoll = (
+          pipeline
+          | beam.Create(sample_main_input_elements)
+          | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
+          | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
+          | beam.Map(lambda x: ('key', x))
+          | "MainInputGBK" >> beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | beam.ParDo(_EmitElement())
+          | "RunInference" >> base.RunInference(
+              model_handler, model_metadata_pcoll=side_input))
+
+      expected_model_id_order = [
+          'fake_model_id_default',
+          'fake_model_id_default',
+          'fake_model_id_1',
+          'fake_model_id_2',
+          'fake_model_id_2'
+      ]
+      expected_result = [
+          base.PredictionResult(
+              example=sample_main_input_elements[i],
+              inference=sample_main_input_elements[i] + 1,
+              model_id=expected_model_id_order[i]) for i in range(5)
+      ]
+
+      assert_that(result_pcoll, equal_to(expected_result))
+
   @unittest.skipIf(
       not TestPipeline().get_pipeline_options().view_as(
           StandardOptions).streaming,
