This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch users/damccorm/mpsRi
in repository https://gitbox.apache.org/repos/asf/beam.git
commit dbdc95a6552b242563e3d7c9a08d9f95449ca869
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Fri May 12 21:08:29 2023 -0400

    Allow model handlers to request multi_process_shared model
---
 sdks/python/apache_beam/ml/inference/base.py      |  36 ++++-
 sdks/python/apache_beam/ml/inference/base_test.py | 174 +++++++++++++++++++++-
 2 files changed, 205 insertions(+), 5 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 0a62c26887b..a60f7365b42 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -33,6 +33,7 @@ import pickle
 import sys
 import threading
 import time
+import uuid
 from typing import Any
 from typing import Callable
 from typing import Dict
@@ -47,6 +48,7 @@ from typing import TypeVar
 from typing import Union
 
 import apache_beam as beam
+from apache_beam.utils import multi_process_shared
 from apache_beam.utils import shared
 
 try:
@@ -227,6 +229,15 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     inference result in order from first applied to last applied."""
     return _PostProcessingModelHandler(self, fn)
 
+  def share_model_across_processes(self) -> bool:
+    """Returns a boolean representing whether or not a model should
+    be shared across multiple processes instead of being loaded per process.
+    This is primarily useful for large models that can't fit multiple copies
+    in memory. Multi-process support may vary by runner, but this will fall
+    back to loading per process as necessary. See
+    https://beam.apache.org/releases/pydoc/current/apache_beam.utils.multi_process_shared.html"""
+    return False
+
 
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
@@ -290,6 +301,9 @@ class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._unkeyed.get_postprocess_fns()
 
+  def share_model_across_processes(self) -> bool:
+    return self._unkeyed.share_model_across_processes()
+
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                              ModelHandler[Union[ExampleT, Tuple[KeyT,
@@ -379,6 +393,9 @@ class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
   def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
     return self._unkeyed.get_postprocess_fns()
 
+  def share_model_across_processes(self) -> bool:
+    return self._unkeyed.share_model_across_processes()
+
 
 class _PreProcessingModelHandler(Generic[ExampleT,
                                          PredictionT,
@@ -538,6 +555,9 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
     self._with_exception_handling = False
     self._watch_model_pattern = watch_model_pattern
     self._kwargs = kwargs
+    # Generate a random tag to use for shared.py and multi_process_shared.py to
+    # allow us to effectively disambiguate in multi-model settings.
+    self._model_tag = uuid.uuid4().hex
 
   def _get_model_metadata_pcoll(self, pipeline):
     # avoid circular imports.
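As context beyond the diff itself: the new hook is opt-in. A concrete handler only has to override share_model_across_processes() to return True, and RunInference will then serve a single copy of its model to all of a worker's processes through apache_beam.utils.multi_process_shared instead of loading one copy per process. A minimal sketch of such a handler; MyLargeModel and MyLargeModelHandler are hypothetical names used purely for illustration, not classes from this commit:

from typing import Iterable, Optional, Sequence

from apache_beam.ml.inference import base


class MyLargeModel:
  """Hypothetical stand-in for a model too large to load once per process."""
  def predict(self, example: int) -> int:
    return example + 1


class MyLargeModelHandler(base.ModelHandler[int, int, MyLargeModel]):
  def __init__(self):
    # Mirrors the test handlers below, which set _env_vars explicitly.
    self._env_vars = {}

  def load_model(self) -> MyLargeModel:
    # Loaded once; RunInference hands the instance to each process through a
    # multi_process_shared proxy because of the override below.
    return MyLargeModel()

  def run_inference(
      self,
      batch: Sequence[int],
      model: MyLargeModel,
      inference_args: Optional[dict] = None) -> Iterable[int]:
    for example in batch:
      yield model.predict(example)

  def share_model_across_processes(self) -> bool:
    # Opt in to the cross-process sharing added by this commit.
    return True

Used as pcoll | base.RunInference(MyLargeModelHandler()), such a handler takes the MultiProcessShared branch added to _load_model in the hunks below rather than the existing shared.Shared path.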
@@ -623,7 +643,8 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
                 self._model_handler,
                 self._clock,
                 self._metrics_namespace,
-                self._enable_side_input_loading),
+                self._enable_side_input_loading,
+                self._model_tag),
             self._inference_args,
             beam.pvalue.AsSingleton(
                 self._model_metadata_pcoll,
@@ -780,7 +801,8 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
       model_handler: ModelHandler[ExampleT, PredictionT, Any],
       clock,
       metrics_namespace,
-      enable_side_input_loading: bool = False):
+      enable_side_input_loading: bool = False,
+      model_tag: str = "RunInference"):
     """A DoFn implementation generic to frameworks.
 
     Args:
@@ -789,6 +811,7 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
       metrics_namespace: Namespace of the transform to collect metrics.
      enable_side_input_loading: Bool to indicate if model updates
        with side inputs.
+      model_tag: Tag to use to disambiguate models in multi-model settings.
     """
     self._model_handler = model_handler
     self._shared_model_handle = shared.Shared()
@@ -797,6 +820,7 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
     self._metrics_namespace = metrics_namespace
     self._enable_side_input_loading = enable_side_input_loading
     self._side_input_path = None
+    self._model_tag = model_tag
 
   def _load_model(self, side_input_model_path: Optional[str] = None):
     def load():
@@ -815,7 +839,13 @@ class _RunInferenceDoFn(beam.DoFn, Generic[ExampleT, PredictionT]):
 
     # TODO(https://github.com/apache/beam/issues/21443): Investigate releasing
     # model.
-    model = self._shared_model_handle.acquire(load, tag=side_input_model_path)
+    if self._model_handler.share_model_across_processes():
+      # TODO - make this a more robust tag than 'RunInference'
+      model = multi_process_shared.MultiProcessShared(
+          load, tag=side_input_model_path or self._model_tag).acquire()
+    else:
+      model = self._shared_model_handle.acquire(
+          load, tag=side_input_model_path or self._model_tag)
     # since shared_model_handle is shared across threads, the model path
     # might not get updated in the model handler
     # because we directly get cached weak ref model from shared cache, instead
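A rough, simplified sketch (not from the commit) of the primitive the new _load_model branch relies on: MultiProcessShared handles constructed with the same tag resolve to one underlying instance, which is why the uuid-based _model_tag above is generated per transform; a fixed tag could let two unrelated RunInference transforms in one pipeline hand out each other's cached model. The Counter class is a made-up stand-in for an expensive-to-load model:

from apache_beam.utils import multi_process_shared


class Counter:
  """Made-up stand-in for an expensive-to-load model."""
  def __init__(self):
    self.value = 0

  def increment(self) -> int:
    self.value += 1
    return self.value


# Same tag, so both handles are backed by a single Counter instance,
# served through a proxy when the caller lives in a different process.
counter_a = multi_process_shared.MultiProcessShared(Counter, tag='my_model').acquire()
counter_b = multi_process_shared.MultiProcessShared(Counter, tag='my_model').acquire()

print(counter_a.increment())  # 1
print(counter_b.increment())  # 2, because it is the same underlying object

In the DoFn above, a side-input model path, when present, takes the place of the tag, so a model update acquires a fresh shared instance rather than reusing the old one.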
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index 455dfe208b1..afd336f9a47 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -50,11 +50,17 @@ class FakeModel:
 
 class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
   def __init__(
-      self, clock=None, min_batch_size=1, max_batch_size=9999, **kwargs):
+      self,
+      clock=None,
+      min_batch_size=1,
+      max_batch_size=9999,
+      multi_process_shared=False,
+      **kwargs):
     self._fake_clock = clock
     self._min_batch_size = min_batch_size
     self._max_batch_size = max_batch_size
     self._env_vars = kwargs.get('env_vars', {})
+    self._multi_process_shared = multi_process_shared
 
   def load_model(self):
     if self._fake_clock:
@@ -66,6 +72,12 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[int]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     if self._fake_clock:
       self._fake_clock.current_time_ns += 3_000_000  # 3 milliseconds
     for example in batch:
@@ -80,13 +92,21 @@ class FakeModelHandler(base.ModelHandler[int, int, FakeModel]):
         'max_batch_size': self._max_batch_size
     }
 
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
 
 class FakeModelHandlerReturnsPredictionResult(
     base.ModelHandler[int, base.PredictionResult, FakeModel]):
-  def __init__(self, clock=None, model_id='fake_model_id_default'):
+  def __init__(
+      self,
+      clock=None,
+      model_id='fake_model_id_default',
+      multi_process_shared=False):
     self.model_id = model_id
     self._fake_clock = clock
     self._env_vars = {}
+    self._multi_process_shared = multi_process_shared
 
   def load_model(self):
     return FakeModel()
@@ -96,6 +116,12 @@ class FakeModelHandlerReturnsPredictionResult(
       batch: Sequence[int],
       model: FakeModel,
       inference_args=None) -> Iterable[base.PredictionResult]:
+    multi_process_shared_loaded = "multi_process_shared" in str(type(model))
+    if self._multi_process_shared != multi_process_shared_loaded:
+      raise Exception(
+          f'Loaded model of type {type(model)}, was' +
+          f'{"" if self._multi_process_shared else " not"} ' +
+          'expecting multi_process_shared_model')
     for example in batch:
       yield base.PredictionResult(
           model_id=self.model_id,
@@ -105,6 +131,9 @@ class FakeModelHandlerReturnsPredictionResult(
   def update_model_path(self, model_path: Optional[str] = None):
     self.model_id = model_path if model_path else self.model_id
 
+  def share_model_across_processes(self):
+    return self._multi_process_shared
+
 
 class FakeClock:
   def __init__(self):
@@ -156,6 +185,15 @@ class RunInferenceBaseTest(unittest.TestCase):
       actual = pcoll | base.RunInference(FakeModelHandler())
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_simple_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      expected = [example + 1 for example in examples]
+      pcoll = pipeline | 'start' >> beam.Create(examples)
+      actual = pcoll | base.RunInference(
+          FakeModelHandler(multi_process_shared=True))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
   def test_run_inference_impl_with_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
@@ -183,6 +221,35 @@ class RunInferenceBaseTest(unittest.TestCase):
           model_handler)
       assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
 
+  def test_run_inference_impl_with_keyed_examples_multi_process_shared(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [(i, example + 1) for i, example in enumerate(examples)]
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      actual = pcoll | base.RunInference(
+          base.KeyedModelHandler(FakeModelHandler(multi_process_shared=True)))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_run_inference_impl_with_maybe_keyed_examples_multi_process_shared(
+      self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [example + 1 for example in examples]
+      keyed_expected = [(i, example + 1) for i, example in enumerate(examples)]
+      model_handler = base.MaybeKeyedModelHandler(
+          FakeModelHandler(multi_process_shared=True))
+
+      pcoll = pipeline | 'Unkeyed' >> beam.Create(examples)
+      actual = pcoll | 'RunUnkeyed' >> base.RunInference(model_handler)
+      assert_that(actual, equal_to(expected), label='CheckUnkeyed')
+
+      keyed_pcoll = pipeline | 'Keyed' >> beam.Create(keyed_examples)
+      keyed_actual = keyed_pcoll | 'RunKeyed' >> base.RunInference(
+          model_handler)
+      assert_that(keyed_actual, equal_to(keyed_expected), label='CheckKeyed')
+
   def test_run_inference_preprocessing(self):
     def mult_two(example: str) -> int:
       return int(example) * 2
@@ -634,6 +701,31 @@ class RunInferenceBaseTest(unittest.TestCase):
         'singleton view. First two elements encountered are' in str(
             e.exception))
 
+  def test_run_inference_with_iterable_side_input_multi_process_shared(self):
+    test_pipeline = TestPipeline()
+    side_input = (
+        test_pipeline | "CreateDummySideInput" >> beam.Create(
+            [base.ModelMetadata(1, 1), base.ModelMetadata(2, 2)])
+        | "ApplySideInputWindow" >> beam.WindowInto(
+            window.GlobalWindows(),
+            trigger=trigger.Repeatedly(trigger.AfterProcessingTime(1)),
+            accumulation_mode=trigger.AccumulationMode.DISCARDING))
+
+    test_pipeline.options.view_as(StandardOptions).streaming = True
+    with self.assertRaises(ValueError) as e:
+      _ = (
+          test_pipeline
+          | beam.Create([1, 2, 3, 4])
+          | base.RunInference(
+              FakeModelHandler(multi_process_shared=True),
+              model_metadata_pcoll=side_input))
+      test_pipeline.run()
+
+    self.assertTrue(
+        'PCollection of size 2 with more than one element accessed as a '
+        'singleton view. First two elements encountered are' in str(
+            e.exception))
+
   def test_run_inference_empty_side_input(self):
     model_handler = FakeModelHandlerReturnsPredictionResult()
     main_input_elements = [1, 2]
@@ -727,6 +819,84 @@ class RunInferenceBaseTest(unittest.TestCase):
 
     assert_that(result_pcoll, equal_to(expected_result))
 
+  def test_run_inference_side_input_in_batch_multi_process_shared(self):
+    first_ts = math.floor(time.time()) - 30
+    interval = 7
+
+    sample_main_input_elements = ([
+        first_ts - 2,
+        first_ts + 1,
+        first_ts + 8,
+        first_ts + 15,
+        first_ts + 22,
+    ])
+
+    sample_side_input_elements = [
+        (first_ts + 1, base.ModelMetadata(model_id='', model_name='')),
+        # if model_id is empty string, we use the default model
+        # handler model URI.
+        (
+            first_ts + 8,
+            base.ModelMetadata(
+                model_id='fake_model_id_1', model_name='fake_model_id_1')),
+        (
+            first_ts + 15,
+            base.ModelMetadata(
+                model_id='fake_model_id_2', model_name='fake_model_id_2'))
+    ]
+
+    model_handler = FakeModelHandlerReturnsPredictionResult(
+        multi_process_shared=True)
+
+    # applying GroupByKey to utilize windowing according to
+    # https://beam.apache.org/documentation/programming-guide/#windowing-bounded-collections
+    class _EmitElement(beam.DoFn):
+      def process(self, element):
+        for e in element:
+          yield e
+
+    with TestPipeline() as pipeline:
+      side_input = (
+          pipeline
+          |
+          "CreateSideInputElements" >> beam.Create(sample_side_input_elements)
+          | beam.Map(lambda x: TimestampedValue(x[1], x[0]))
+          | beam.WindowInto(
+              window.FixedWindows(interval),
+              accumulation_mode=trigger.AccumulationMode.DISCARDING)
+          | beam.Map(lambda x: ('key', x))
+          | beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | "EmitSideInput" >> beam.ParDo(_EmitElement()))
+
+      result_pcoll = (
+          pipeline
+          | beam.Create(sample_main_input_elements)
+          | "MapTimeStamp" >> beam.Map(lambda x: TimestampedValue(x, x))
+          | "ApplyWindow" >> beam.WindowInto(window.FixedWindows(interval))
+          | beam.Map(lambda x: ('key', x))
+          | "MainInputGBK" >> beam.GroupByKey()
+          | beam.Map(lambda x: x[1])
+          | beam.ParDo(_EmitElement())
+          | "RunInference" >> base.RunInference(
+              model_handler, model_metadata_pcoll=side_input))
+
+      expected_model_id_order = [
+          'fake_model_id_default',
+          'fake_model_id_default',
+          'fake_model_id_1',
+          'fake_model_id_2',
+          'fake_model_id_2'
+      ]
+      expected_result = [
+          base.PredictionResult(
+              example=sample_main_input_elements[i],
+              inference=sample_main_input_elements[i] + 1,
+              model_id=expected_model_id_order[i]) for i in range(5)
+      ]
+
+      assert_that(result_pcoll, equal_to(expected_result))
+
   @unittest.skipIf(
       not TestPipeline().get_pipeline_options().view_as(
           StandardOptions).streaming,
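From a pipeline author's perspective nothing else changes; the new tests above use the ordinary RunInference shape. A hedged end-to-end sketch, reusing the hypothetical MyLargeModelHandler from the earlier note:

import apache_beam as beam
from apache_beam.ml.inference import base

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | beam.Create([1, 5, 3, 10])
      # Because the handler opts in, the model is loaded once and shared
      # across the worker's processes instead of once per process.
      | base.RunInference(MyLargeModelHandler())
      | beam.Map(print))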