This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push:
     new a5f1347699a Add ability to run per key inference (#27857)
a5f1347699a is described below

commit a5f1347699a3ed142d7d066922c3d1002f0b0f31
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Fri Aug 11 18:49:17 2023 -0400

    Add ability to run per key inference (#27857)

    * Add ability to run per key inference

    * lint

    * lint

    * address feedback

    * lint

    * Small feedback updates
---
 sdks/python/apache_beam/ml/inference/base.py      | 348 +++++++++++++++-------
 sdks/python/apache_beam/ml/inference/base_test.py |  65 ++++
 2 files changed, 308 insertions(+), 105 deletions(-)

diff --git a/sdks/python/apache_beam/ml/inference/base.py b/sdks/python/apache_beam/ml/inference/base.py
index 8acdbaa5da1..5f2b4dc465f 100644
--- a/sdks/python/apache_beam/ml/inference/base.py
+++ b/sdks/python/apache_beam/ml/inference/base.py
@@ -35,11 +35,13 @@ import threading
 import time
 import uuid
 from collections import OrderedDict
+from collections import defaultdict
 from typing import Any
 from typing import Callable
 from typing import Dict
 from typing import Generic
 from typing import Iterable
+from typing import List
 from typing import Mapping
 from typing import NamedTuple
 from typing import Optional
@@ -243,70 +245,278 @@ class ModelHandler(Generic[ExampleT, PredictionT, ModelT]):
     return False
 
 
+class _ModelManager:
+  """
+  A class for efficiently managing copies of multiple models. Will load a
+  single copy of each model into a multi_process_shared object and then
+  return a lookup key for that object. Optionally takes in a max_models
+  parameter, if that is set it will only hold that many models in memory at
+  once before evicting one (using LRU logic).
+  """
+  def __init__(
+      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
+    """
+    Args:
+      mh_map: A map from keys to model handlers which can be used to load a
+        model.
+      max_models: The maximum number of models to load at any given time
+        before evicting 1 from memory (using LRU logic). Leave as None to
+        allow unlimited models.
+    """
+    self._max_models = max_models
+    self._mh_map: Dict[str, ModelHandler] = mh_map
+    self._proxy_map: Dict[str, str] = {}
+    self._tag_map: Dict[
+        str, multi_process_shared.MultiProcessShared] = OrderedDict()
+
+  def load(self, key: str) -> str:
+    """
+    Loads the appropriate model for the given key into memory.
+    Args:
+      key: the key associated with the model we'd like to load.
+    Returns:
+      the tag we can use to access the model using multi_process_shared.py.
+    """
+    # Map the key for a model to a unique tag that will persist until the model
+    # is released. This needs to be unique between releasing/reacquiring th
+    # model because otherwise the ProxyManager will try to reuse the model that
+    # has been released and deleted.
+    if key in self._tag_map:
+      self._tag_map.move_to_end(key)
+    else:
+      self._tag_map[key] = uuid.uuid4().hex
+
+    tag = self._tag_map[key]
+    mh = self._mh_map[key]
+
+    if self._max_models is not None and self._max_models < len(self._tag_map):
+      # If we're about to exceed our LRU size, release the last used model.
+      tag_to_remove = self._tag_map.popitem(last=False)[1]
+      shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
+      shared_handle.release(model_to_remove)
+
+    # Load the new model
+    shared_handle = multi_process_shared.MultiProcessShared(
+        mh.load_model, tag=tag)
+    model_reference = shared_handle.acquire()
+    self._proxy_map[tag] = (shared_handle, model_reference)
+
+    return tag
+
+  def increment_max_models(self, increment: int):
+    """
+    Increments the number of models that this instance of a _ModelManager is
+    able to hold.
+    Args:
+      increment: the amount by which we are incrementing the number of models.
+    """
+    if self._max_models is None:
+      raise ValueError(
+          "Cannot increment max_models if self._max_models is None (unlimited" +
+          " models mode).")
+    self._max_models += increment
+
+
+# Use a dataclass instead of named tuple because NamedTuples and generics don't
+# mix well across the board for all versions:
+# https://github.com/python/typing/issues/653
+class KeyMhMapping(Generic[KeyT, ExampleT, PredictionT, ModelT]):
+  """
+  Dataclass for mapping 1 or more keys to 1 model handler.
+  Given `KeyMhMapping(['key1', 'key2'], myMh)`, all examples with keys `key1`
+  or `key2` will be run against the model defined by the `myMh` ModelHandler.
+  """
+  def __init__(
+      self, keys: List[KeyT], mh: ModelHandler[ExampleT, PredictionT, ModelT]):
+    self.keys = keys
+    self.mh = mh
+
+
 class KeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
                         ModelHandler[Tuple[KeyT, ExampleT],
                                      Tuple[KeyT, PredictionT],
-                                     ModelT]):
-  def __init__(self, unkeyed: ModelHandler[ExampleT, PredictionT, ModelT]):
+                                     Union[ModelT, _ModelManager]]):
+  def __init__(
+      self,
+      unkeyed: Union[ModelHandler[ExampleT, PredictionT, ModelT],
+                     List[KeyMhMapping[KeyT, ExampleT, PredictionT, ModelT]]]):
     """A ModelHandler that takes keyed examples and returns keyed predictions.
 
     For example, if the original model is used with RunInference to take a
     PCollection[E] to a PCollection[P], this ModelHandler would take a
     PCollection[Tuple[K, E]] to a PCollection[Tuple[K, P]], making it possible
-    to use the key to associate the outputs with the inputs.
+    to use the key to associate the outputs with the inputs. KeyedModelHandler
+    is able to accept either a single unkeyed ModelHandler or many different
+    model handlers corresponding to the keys for which that ModelHandler should
+    be used. For example, the following configuration could be used to map keys
+    1-3 to ModelHandler1 and keys 4-5 to ModelHandler2:
+
+      k1 = ['k1', 'k2', 'k3']
+      k2 = ['k4', 'k5']
+      KeyedModelHandler([KeyMhMapping(k1, mh1), KeyMhMapping(k2, mh2)])
+
+    Note that a single copy of each of these models may all be held in memory
+    at the same time; be careful not to load too many large models or your
+    pipeline may cause Out of Memory exceptions.
 
     Args:
-      unkeyed: An implementation of ModelHandler that does not require keys.
+      unkeyed: Either (a) an implementation of ModelHandler that does not
+        require keys or (b) a list of KeyMhMappings mapping lists of keys to
+        unkeyed ModelHandlers.
     """
-    if len(unkeyed.get_preprocess_fns()) or len(unkeyed.get_postprocess_fns()):
-      raise Exception(
-          'Cannot make make an unkeyed model handler with pre or '
-          'postprocessing functions defined into a keyed model handler. All '
-          'pre/postprocessing functions must be defined on the outer model'
-          'handler.')
-    self._unkeyed = unkeyed
-    self._env_vars = unkeyed._env_vars
-
-  def load_model(self) -> ModelT:
-    return self._unkeyed.load_model()
+    self._single_model = not isinstance(unkeyed, list)
+    if self._single_model:
+      if len(unkeyed.get_preprocess_fns()) or len(
+          unkeyed.get_postprocess_fns()):
+        raise Exception(
+            'Cannot make make an unkeyed model handler with pre or '
+            'postprocessing functions defined into a keyed model handler. All '
+            'pre/postprocessing functions must be defined on the outer model'
+            'handler.')
+      self._env_vars = unkeyed._env_vars
+      self._unkeyed = unkeyed
+      return
+
+    # To maintain an efficient representation, we will map all keys in a given
+    # KeyMhMapping to a single id (the first key in the KeyMhMapping list).
+    # We will then map that key to a ModelHandler. This will allow us to
+    # quickly look up the appropriate ModelHandler for any given key.
+    self._id_to_mh_map: Dict[str, ModelHandler[ExampleT, PredictionT,
+                                               ModelT]] = {}
+    self._key_to_id_map: Dict[str, str] = {}
+    for mh_tuple in unkeyed:
+      mh = mh_tuple.mh
+      keys = mh_tuple.keys
+      if len(mh.get_preprocess_fns()) or len(mh.get_postprocess_fns()):
+        raise ValueError(
+            'Cannot use an unkeyed model handler with pre or '
+            'postprocessing functions defined in a keyed model handler. All '
+            'pre/postprocessing functions must be defined on the outer model'
+            'handler.')
+      hints = mh.get_resource_hints()
+      if len(hints) > 0:
+        logging.warning(
+            'mh %s defines the following resource hints, which will be'
+            'ignored: %s. Resource hints are not respected when more than one '
+            'model handler is used in a KeyedModelHandler. If you would like '
+            'to specify resource hints, you can do so by overriding the '
+            'KeyedModelHandler.get_resource_hints() method.',
+            mh,
+            hints)
+      batch_kwargs = mh.batch_elements_kwargs()
+      if len(hints) > 0:
+        logging.warning(
+            'mh %s defines the following batching kwargs which will be '
+            'ignored %s. Batching kwargs are not respected when '
+            'more than one model handler is used in a KeyedModelHandler. If '
+            'you would like to specify resource hints, you can do so by '
+            'overriding the KeyedModelHandler.batch_elements_kwargs() method.',
+            hints,
+            batch_kwargs)
+      env_vars = mh._env_vars
+      if len(hints) > 0:
+        logging.warning(
+            'mh %s defines the following _env_vars which will be ignored %s. '
+            '_env_vars are not respected when more than one model handler is '
+            'used in a KeyedModelHandler. If you need env vars set at '
+            'inference time, you can do so with '
+            'a custom inference function.',
+            mh,
+            env_vars)
+
+      if len(keys) == 0:
+        raise ValueError(
+            f'Empty list maps to model handler {mh}. All model handlers must '
+            'have one or more associated keys.')
+      self._id_to_mh_map[keys[0]] = mh
+      for key in keys:
+        if key in self._key_to_id_map:
+          raise ValueError(
+              f'key {key} maps to multiple model handlers. All keys must map '
+              'to exactly one model handler.')
+        self._key_to_id_map[key] = keys[0]
+
+  def load_model(self) -> Union[ModelT, _ModelManager]:
+    if self._single_model:
+      return self._unkeyed.load_model()
+    return _ModelManager(self._id_to_mh_map)
 
   def run_inference(
       self,
       batch: Sequence[Tuple[KeyT, ExampleT]],
-      model: ModelT,
+      model: Union[ModelT, _ModelManager],
       inference_args: Optional[Dict[str, Any]] = None
   ) -> Iterable[Tuple[KeyT, PredictionT]]:
-    keys, unkeyed_batch = zip(*batch)
-    return zip(
-        keys, self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+    if self._single_model:
+      keys, unkeyed_batch = zip(*batch)
+      return zip(
+          keys,
+          self._unkeyed.run_inference(unkeyed_batch, model, inference_args))
+
+    batch_by_key = defaultdict(list)
+    key_by_id = defaultdict(set)
+    for key, example in batch:
+      batch_by_key[key].append(example)
+      key_by_id[self._key_to_id_map[key]].add(key)
+
+    predictions = []
+    for id, keys in key_by_id.items():
+      mh = self._id_to_mh_map[id]
+      keyed_model_tag = model.load(id)
+      keyed_model_shared_handle = multi_process_shared.MultiProcessShared(
+          mh.load_model, tag=keyed_model_tag)
+      keyed_model = keyed_model_shared_handle.acquire()
+      for key in keys:
+        unkeyed_batches = batch_by_key[key]
+        for inf in mh.run_inference(unkeyed_batches,
+                                    keyed_model,
+                                    inference_args):
+          predictions.append((key, inf))
+      keyed_model_shared_handle.release(keyed_model)
+
+    return predictions
 
   def get_num_bytes(self, batch: Sequence[Tuple[KeyT, ExampleT]]) -> int:
-    keys, unkeyed_batch = zip(*batch)
-    return len(pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    if self._single_model:
+      keys, unkeyed_batch = zip(*batch)
+      return len(
+          pickle.dumps(keys)) + self._unkeyed.get_num_bytes(unkeyed_batch)
+    return len(pickle.dumps(batch))
 
   def get_metrics_namespace(self) -> str:
-    return self._unkeyed.get_metrics_namespace()
+    if self._single_model:
+      return self._unkeyed.get_metrics_namespace()
+    return 'BeamML_KeyedModels'
 
   def get_resource_hints(self):
-    return self._unkeyed.get_resource_hints()
+    if self._single_model:
+      return self._unkeyed.get_resource_hints()
+    return {}
 
   def batch_elements_kwargs(self):
-    return self._unkeyed.batch_elements_kwargs()
+    if self._single_model:
+      return self._unkeyed.batch_elements_kwargs()
+    return {}
 
   def validate_inference_args(self, inference_args: Optional[Dict[str, Any]]):
-    return self._unkeyed.validate_inference_args(inference_args)
+    if self._single_model:
+      return self._unkeyed.validate_inference_args(inference_args)
+    for mh in self._id_to_mh_map.values():
+      mh.validate_inference_args(inference_args)
 
   def update_model_path(self, model_path: Optional[str] = None):
-    return self._unkeyed.update_model_path(model_path=model_path)
-
-  def get_preprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
-    return self._unkeyed.get_preprocess_fns()
-
-  def get_postprocess_fns(self) -> Iterable[Callable[[Any], Any]]:
-    return self._unkeyed.get_postprocess_fns()
+    if self._single_model:
+      return self._unkeyed.update_model_path(model_path=model_path)
+    if model_path is not None:
+      raise RuntimeError(
+          'Model updates are currently not supported for ' +
+          'KeyedModelHandlers with multiple different per-key ' +
+          'ModelHandlers.')
 
   def share_model_across_processes(self) -> bool:
-    return self._unkeyed.share_model_across_processes()
+    if self._single_model:
+      return self._unkeyed.share_model_across_processes()
+    return True
 
 
 class MaybeKeyedModelHandler(Generic[KeyT, ExampleT, PredictionT, ModelT],
@@ -740,78 +950,6 @@ class RunInference(beam.PTransform[beam.PCollection[ExampleT],
     return self
 
 
-class _ModelManager:
-  """
-  A class for efficiently managing copies of multiple models. Will load a
-  single copy of each model into a multi_process_shared object and then
-  return a lookup key for that object. Optionally takes in a max_models
-  parameter, if that is set it will only hold that many models in memory at
-  once before evicting one (using LRU logic).
-  """
-  def __init__(
-      self, mh_map: Dict[str, ModelHandler], max_models: Optional[int] = None):
-    """
-    Args:
-      mh_map: A map from keys to model handlers which can be used to load a
-        model.
-      max_models: The maximum number of models to load at any given time
-        before evicting 1 from memory (using LRU logic). Leave as None to
-        allow unlimited models.
-    """
-    self._max_models = max_models
-    self._mh_map: Dict[str, ModelHandler] = mh_map
-    self._proxy_map: Dict[str, str] = {}
-    self._tag_map: Dict[
-        str, multi_process_shared.MultiProcessShared] = OrderedDict()
-
-  def load(self, key: str) -> str:
-    """
-    Loads the appropriate model for the given key into memory.
-    Args:
-      key: the key associated with the model we'd like to load.
-    Returns:
-      the tag we can use to access the model using multi_process_shared.py.
-    """
-    # Map the key for a model to a unique tag that will persist until the model
-    # is released. This needs to be unique between releasing/reacquiring th
-    # model because otherwise the ProxyManager will try to reuse the model that
-    # has been released and deleted.
-    if key in self._tag_map:
-      self._tag_map.move_to_end(key)
-    else:
-      self._tag_map[key] = uuid.uuid4().hex
-
-    tag = self._tag_map[key]
-    mh = self._mh_map[key]
-
-    if self._max_models is not None and self._max_models < len(self._tag_map):
-      # If we're about to exceed our LRU size, release the last used model.
-      tag_to_remove = self._tag_map.popitem(last=False)[1]
-      shared_handle, model_to_remove = self._proxy_map[tag_to_remove]
-      shared_handle.release(model_to_remove)
-
-    # Load the new model
-    shared_handle = multi_process_shared.MultiProcessShared(
-        mh.load_model, tag=tag)
-    model_reference = shared_handle.acquire()
-    self._proxy_map[tag] = (shared_handle, model_reference)
-
-    return tag
-
-  def increment_max_models(self, increment: int):
-    """
-    Increments the number of models that this instance of a _ModelManager is
-    able to hold.
-    Args:
-      increment: the amount by which we are incrementing the number of models.
-    """
-    if self._max_models is None:
-      raise ValueError(
-          "Cannot increment max_models if self._max_models is None (unlimited" +
-          " models mode).")
-    self._max_models += increment
-
-
 class _MetricsCollector:
   """
   A metrics collector that tracks ML related performance and memory usage.
diff --git a/sdks/python/apache_beam/ml/inference/base_test.py b/sdks/python/apache_beam/ml/inference/base_test.py
index be91efb9479..c79189718a9 100644
--- a/sdks/python/apache_beam/ml/inference/base_test.py
+++ b/sdks/python/apache_beam/ml/inference/base_test.py
@@ -254,6 +254,71 @@ class RunInferenceBaseTest(unittest.TestCase):
           base.KeyedModelHandler(FakeModelHandler()))
       assert_that(actual, equal_to(expected), label='assert:inferences')
 
+  def test_run_inference_impl_with_keyed_examples_many_model_handlers(self):
+    with TestPipeline() as pipeline:
+      examples = [1, 5, 3, 10]
+      keyed_examples = [(i, example) for i, example in enumerate(examples)]
+      expected = [(i, example + 1) for i, example in enumerate(examples)]
+      expected[0] = (0, 200)
+      pcoll = pipeline | 'start' >> beam.Create(keyed_examples)
+      mhs = [
+          base.KeyMhMapping([0],
+                            FakeModelHandler(
+                                state=200, multi_process_shared=True)),
+          base.KeyMhMapping([1, 2, 3],
+                            FakeModelHandler(multi_process_shared=True))
+      ]
+      actual = pcoll | base.RunInference(base.KeyedModelHandler(mhs))
+      assert_that(actual, equal_to(expected), label='assert:inferences')
+
+  def test_keyed_many_model_handlers_validation(self):
+    def mult_two(example: str) -> int:
+      return int(example) * 2
+
+    mhs = [
+        base.KeyMhMapping(
+            [0],
+            FakeModelHandler(
+                state=200,
+                multi_process_shared=True).with_preprocess_fn(mult_two)),
+        base.KeyMhMapping([1, 2, 3],
+                          FakeModelHandler(multi_process_shared=True))
+    ]
+    with self.assertRaises(ValueError):
+      base.KeyedModelHandler(mhs)
+
+    mhs = [
+        base.KeyMhMapping(
+            [0],
+            FakeModelHandler(
+                state=200,
+                multi_process_shared=True).with_postprocess_fn(mult_two)),
+        base.KeyMhMapping([1, 2, 3],
+                          FakeModelHandler(multi_process_shared=True))
+    ]
+    with self.assertRaises(ValueError):
+      base.KeyedModelHandler(mhs)
+
+    mhs = [
+        base.KeyMhMapping([0],
+                          FakeModelHandler(
+                              state=200, multi_process_shared=True)),
+        base.KeyMhMapping([0, 1, 2, 3],
+                          FakeModelHandler(multi_process_shared=True))
+    ]
+    with self.assertRaises(ValueError):
+      base.KeyedModelHandler(mhs)
+
+    mhs = [
+        base.KeyMhMapping([],
+                          FakeModelHandler(
+                              state=200, multi_process_shared=True)),
+        base.KeyMhMapping([0, 1, 2, 3],
+                          FakeModelHandler(multi_process_shared=True))
+    ]
+    with self.assertRaises(ValueError):
+      base.KeyedModelHandler(mhs)
+
   def test_run_inference_impl_with_maybe_keyed_examples(self):
     with TestPipeline() as pipeline:
       examples = [1, 5, 3, 10]
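For readers trying out the new API, below is a minimal, untested usage sketch of per-key inference based on the KeyedModelHandler docstring and the tests in this commit. The sklearn model handlers, model paths, keys, and example arrays are illustrative assumptions only and are not part of this change; any unkeyed ModelHandler without pre/postprocessing functions attached should work in their place.

  import apache_beam as beam
  import numpy as np

  from apache_beam.ml.inference import base
  from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy

  # Hypothetical model locations; substitute your own pickled sklearn models.
  mh_a = SklearnModelHandlerNumpy(model_uri='gs://my-bucket/model_a.pkl')
  mh_b = SklearnModelHandlerNumpy(model_uri='gs://my-bucket/model_b.pkl')

  # Examples keyed by 'k1'-'k3' are routed to mh_a, and 'k4'-'k5' to mh_b.
  # Per the docstring above, a single copy of each model may be held in
  # memory at the same time.
  per_key_mh = base.KeyedModelHandler([
      base.KeyMhMapping(['k1', 'k2', 'k3'], mh_a),
      base.KeyMhMapping(['k4', 'k5'], mh_b),
  ])

  with beam.Pipeline() as p:
    _ = (
        p
        | 'CreateKeyedExamples' >> beam.Create([
            ('k1', np.array([1.0, 2.0])),
            ('k4', np.array([3.0, 4.0])),
        ])
        | 'RunPerKeyInference' >> base.RunInference(per_key_mh)
        | 'Print' >> beam.Map(print))

The output is a PCollection of (key, PredictionResult) tuples, so downstream steps can still associate each prediction with the key (and therefore the model) that produced it.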