damccorm commented on code in PR #33406:
URL: https://github.com/apache/beam/pull/33406#discussion_r1893997507
##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -33,11 +40,411 @@
tft = None # type: ignore
+class ModelHandlerProvider:
+ handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}
+
+ def __init__(
+ self,
+ handler,
+ preprocess: Optional[Dict[str, str]] = None,
+ postprocess: Optional[Dict[str, str]] = None):
+ self._handler = handler
+ self._preprocess_fn = self.parse_processing_transform(
+ preprocess, 'preprocess') or self.default_preprocess_fn()
+ self._postprocess_fn = self.parse_processing_transform(
+ postprocess, 'postprocess') or self.default_postprocess_fn()
+
+ def inference_output_type(self):
+ return Any
+
+ @staticmethod
+ def parse_processing_transform(processing_transform, typ):
+ def _parse_config(callable=None, path=None, name=None):
+ if callable and (path or name):
+ raise ValueError(
+ f"Cannot specify 'callable' with 'path' and 'name' for {typ} "
+ f"function.")
+ if path and name:
+ return python_callable.PythonCallableWithSource.load_from_script(
+ FileSystems.open(path).read().decode(), name)
+ elif callable:
+ return python_callable.PythonCallableWithSource(callable)
+ else:
+ raise ValueError(
+ f"Must specify one of 'callable' or 'path' and 'name' for {typ} "
+ f"function.")
+
+ if processing_transform:
+ if isinstance(processing_transform, dict):
+ return _parse_config(**processing_transform)
+ else:
+ raise ValueError("Invalid model_handler specification.")
+
+ def underlying_handler(self):
+ return self._handler
+
+ @staticmethod
+ def default_preprocess_fn():
+ raise ValueError(
+ 'Handler does not implement a default preprocess '
+ 'method. Please define a preprocessing method using the '
+ '\'preprocess\' tag.')
+
+ def _preprocess_fn_internal(self):
+ return lambda row: (row, self._preprocess_fn(row))
+
+ @staticmethod
+ def default_postprocess_fn():
+ return lambda x: x
+
+ def _postprocess_fn_internal(self):
+ return lambda result: (result[0], self._postprocess_fn(result[1]))
+
+ @staticmethod
+ def validate(model_handler_spec):
+ raise NotImplementedError(type(ModelHandlerProvider))
+
+ @classmethod
+ def register_handler_type(cls, type_name):
+ def apply(constructor):
+ cls.handler_types[type_name] = constructor
+ return constructor
+
+ return apply
+
+ @classmethod
+ def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider":
+ typ = model_handler_spec['type']
+ config = model_handler_spec['config']
+ try:
+ result = cls.handler_types[typ](**config)
+ if not hasattr(result, 'to_json'):
+ result.to_json = lambda: model_handler_spec
+ return result
+ except Exception as exn:
+ raise ValueError(
+ f'Unable to instantiate model handler of type {typ}. {exn}')
+
+
[email protected]_handler_type('VertexAIModelHandlerJSON')
+class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
+ def __init__(
+ self,
+ endpoint_id: str,
+ project: str,
+ location: str,
+ preprocess: Dict[str, str],
+ experiment: Optional[str] = None,
+ network: Optional[str] = None,
+ private: bool = False,
+ min_batch_size: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
+ max_batch_duration_secs: Optional[int] = None,
+ env_vars: Optional[Dict[str, Any]] = None,
+ postprocess: Optional[Dict[str, str]] = None):
+ """
+ ModelHandler for Vertex AI.
+
+ For example: ::
+
+ - type: RunInference
+ config:
+ inference_tag: 'my_inference'
+ model_handler:
+ type: VertexAIModelHandlerJSON
+ config:
+ endpoint_id: 9876543210
+ project: my-project
+ location: us-east1
+ preprocess:
+ callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'
+
+ Args:
+ endpoint_id: the numerical ID of the Vertex AI endpoint to query.
+ project: the GCP project name where the endpoint is deployed.
+ location: the GCP location where the endpoint is deployed.
+ experiment: Experiment label to apply to the
+ queries. See
+      https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
+ for more information.
+ network: The full name of the Compute Engine
+ network the endpoint is deployed on; used for private
+ endpoints. The network or subnetwork Dataflow pipeline
+ option must be set and match this network for pipeline
+ execution.
+ Ex: "projects/12345/global/networks/myVPC"
+ private: If the deployed Vertex AI endpoint is
+ private, set to true. Requires a network to be provided
+ as well.
+ min_batch_size: The minimum batch size to use when batching
+ inputs.
+ max_batch_size: The maximum batch size to use when batching
+ inputs.
+ max_batch_duration_secs: The maximum amount of time to buffer
+ a batch before emitting; used in streaming contexts.
+ env_vars: Environment variables.
+ preprocess: A python callable, defined either inline, or using a file,
+ that is invoked on the input row before sending to the model to be
+ loaded by this ModelHandler. This parameter is required by the
+ `VertexAIModelHandlerJSON` ModelHandler.
+ postprocess: A python callable, defined either inline, or using a file,
+ that is invoked on the PredictionResult output by the ModelHandler
+ before parsing into the output Beam Row under the field name defined
+ by the inference_tag.
+ """
+
+ try:
+      from apache_beam.ml.inference.vertex_ai_inference import VertexAIModelHandlerJSON
+ except ImportError:
+ raise ValueError(
+ 'Unable to import VertexAIModelHandlerJSON. Please '
+ 'install gcp dependencies: `pip install apache_beam[gcp]`')
+
+ _handler = VertexAIModelHandlerJSON(
+ endpoint_id=str(endpoint_id),
+ project=project,
+ location=location,
+ experiment=experiment,
+ network=network,
+ private=private,
+ min_batch_size=min_batch_size,
+ max_batch_size=max_batch_size,
+ max_batch_duration_secs=max_batch_duration_secs,
+ env_vars=env_vars or {})
+
+ super().__init__(_handler, preprocess, postprocess)
+
+ @staticmethod
+ def validate(model_handler_spec):
+ pass
+
+ def inference_output_type(self):
+ return RowTypeConstraint.from_fields([('example', Any), ('inference', Any),
Review Comment:
We don't generally today, but we could for a very limited set of handlers.
The ones I can think of with predictable output types are hugging face
pipelines and vLLM; the rest are all dependent on the model.
I think this is probably worth doing when we can, it probably requires some
slight modification to the PredictionResult type, though, and might be worth
scoping into a follow on PR.
##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -33,11 +40,411 @@
tft = None # type: ignore
+class ModelHandlerProvider:
+ handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}
+
+ def __init__(
+ self,
+ handler,
+ preprocess: Optional[Dict[str, str]] = None,
+ postprocess: Optional[Dict[str, str]] = None):
+ self._handler = handler
+ self._preprocess_fn = self.parse_processing_transform(
+ preprocess, 'preprocess') or self.default_preprocess_fn()
+ self._postprocess_fn = self.parse_processing_transform(
+ postprocess, 'postprocess') or self.default_postprocess_fn()
+
+ def inference_output_type(self):
+ return Any
+
+ @staticmethod
+ def parse_processing_transform(processing_transform, typ):
+ def _parse_config(callable=None, path=None, name=None):
+ if callable and (path or name):
+ raise ValueError(
+ f"Cannot specify 'callable' with 'path' and 'name' for {typ} "
+ f"function.")
+ if path and name:
+ return python_callable.PythonCallableWithSource.load_from_script(
+ FileSystems.open(path).read().decode(), name)
+ elif callable:
+ return python_callable.PythonCallableWithSource(callable)
+ else:
+ raise ValueError(
+ f"Must specify one of 'callable' or 'path' and 'name' for {typ} "
+ f"function.")
+
+ if processing_transform:
+ if isinstance(processing_transform, dict):
+ return _parse_config(**processing_transform)
+ else:
+ raise ValueError("Invalid model_handler specification.")
+
+ def underlying_handler(self):
+ return self._handler
+
+ @staticmethod
+ def default_preprocess_fn():
+ raise ValueError(
+ 'Handler does not implement a default preprocess '
+ 'method. Please define a preprocessing method using the '
+ '\'preprocess\' tag.')
+
+ def _preprocess_fn_internal(self):
+ return lambda row: (row, self._preprocess_fn(row))
+
+ @staticmethod
+ def default_postprocess_fn():
+ return lambda x: x
+
+ def _postprocess_fn_internal(self):
+ return lambda result: (result[0], self._postprocess_fn(result[1]))
+
+ @staticmethod
+ def validate(model_handler_spec):
+ raise NotImplementedError(type(ModelHandlerProvider))
+
+ @classmethod
+ def register_handler_type(cls, type_name):
+ def apply(constructor):
+ cls.handler_types[type_name] = constructor
+ return constructor
+
+ return apply
+
+ @classmethod
+ def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider":
+ typ = model_handler_spec['type']
+ config = model_handler_spec['config']
+ try:
+ result = cls.handler_types[typ](**config)
+ if not hasattr(result, 'to_json'):
+ result.to_json = lambda: model_handler_spec
+ return result
+ except Exception as exn:
+ raise ValueError(
+ f'Unable to instantiate model handler of type {typ}. {exn}')
+
+
[email protected]_handler_type('VertexAIModelHandlerJSON')
+class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
+ def __init__(
+ self,
+ endpoint_id: str,
+ project: str,
+ location: str,
+ preprocess: Dict[str, str],
+ experiment: Optional[str] = None,
+ network: Optional[str] = None,
+ private: bool = False,
+ min_batch_size: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
+ max_batch_duration_secs: Optional[int] = None,
+ env_vars: Optional[Dict[str, Any]] = None,
+ postprocess: Optional[Dict[str, str]] = None):
+ """
+ ModelHandler for Vertex AI.
+
+ For example: ::
+
+ - type: RunInference
+ config:
+ inference_tag: 'my_inference'
+ model_handler:
+ type: VertexAIModelHandlerJSON
+ config:
+ endpoint_id: 9876543210
+ project: my-project
+ location: us-east1
+ preprocess:
+ callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'
+
+ Args:
+ endpoint_id: the numerical ID of the Vertex AI endpoint to query.
+ project: the GCP project name where the endpoint is deployed.
+ location: the GCP location where the endpoint is deployed.
+ experiment: Experiment label to apply to the
+ queries. See
+      https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
+ for more information.
+ network: The full name of the Compute Engine
+ network the endpoint is deployed on; used for private
+ endpoints. The network or subnetwork Dataflow pipeline
+ option must be set and match this network for pipeline
+ execution.
+ Ex: "projects/12345/global/networks/myVPC"
+ private: If the deployed Vertex AI endpoint is
+ private, set to true. Requires a network to be provided
+ as well.
+ min_batch_size: The minimum batch size to use when batching
+ inputs.
+ max_batch_size: The maximum batch size to use when batching
+ inputs.
+ max_batch_duration_secs: The maximum amount of time to buffer
+ a batch before emitting; used in streaming contexts.
+ env_vars: Environment variables.
+ preprocess: A python callable, defined either inline, or using a file,
Review Comment:
Nit: args don't match the order defined in the function, and specifically
you have optional args before required args.
##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -33,11 +40,411 @@
tft = None # type: ignore
+class ModelHandlerProvider:
+ handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}
+
+ def __init__(
+ self,
+ handler,
+ preprocess: Optional[Dict[str, str]] = None,
+ postprocess: Optional[Dict[str, str]] = None):
+ self._handler = handler
+ self._preprocess_fn = self.parse_processing_transform(
+ preprocess, 'preprocess') or self.default_preprocess_fn()
+ self._postprocess_fn = self.parse_processing_transform(
+ postprocess, 'postprocess') or self.default_postprocess_fn()
+
+ def inference_output_type(self):
+ return Any
+
+ @staticmethod
+ def parse_processing_transform(processing_transform, typ):
+ def _parse_config(callable=None, path=None, name=None):
+ if callable and (path or name):
+ raise ValueError(
+ f"Cannot specify 'callable' with 'path' and 'name' for {typ} "
+ f"function.")
+ if path and name:
+ return python_callable.PythonCallableWithSource.load_from_script(
+ FileSystems.open(path).read().decode(), name)
+ elif callable:
+ return python_callable.PythonCallableWithSource(callable)
+ else:
+ raise ValueError(
+ f"Must specify one of 'callable' or 'path' and 'name' for {typ} "
+ f"function.")
+
+ if processing_transform:
+ if isinstance(processing_transform, dict):
+ return _parse_config(**processing_transform)
+ else:
+ raise ValueError("Invalid model_handler specification.")
+
+ def underlying_handler(self):
+ return self._handler
+
+ @staticmethod
+ def default_preprocess_fn():
+ raise ValueError(
+ 'Handler does not implement a default preprocess '
+ 'method. Please define a preprocessing method using the '
+ '\'preprocess\' tag.')
Review Comment:
I'm not sure how actionable people will find this - maybe we can add more
information like: "This is required in most cases because... For an example
preprocess method, see `VertexAIModelHandlerJSONProvider`".
##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -33,11 +40,419 @@
tft = None # type: ignore
+def normalize_ml(spec):
+ if spec['type'] == 'RunInference':
+ config = spec.get('config')
+ for required in ('model_handler', ):
+ if required not in config:
+ raise ValueError(
+ f'Missing {required} parameter in RunInference config '
+ f'at line {SafeLineLoader.get_line(spec)}')
+ model_handler = config.get('model_handler')
+ if not isinstance(model_handler, dict):
+ raise ValueError(
+ 'Invalid model_handler specification at line '
+ f'{SafeLineLoader.get_line(spec)}. Expected '
+ f'dict but was {type(model_handler)}.')
+ for required in ('type', 'config'):
+ if required not in model_handler:
+ raise ValueError(
+ f'Missing {required} in model handler '
+ f'at line {SafeLineLoader.get_line(model_handler)}')
+ typ = model_handler['type']
+ extra_params = set(SafeLineLoader.strip_metadata(model_handler).keys()) - {
+ 'type', 'config'
+ }
+ if extra_params:
+ raise ValueError(
+ f'Unexpected parameters in model handler of type {typ} '
+ f'at line {SafeLineLoader.get_line(spec)}: {extra_params}')
+ model_handler_provider = ModelHandlerProvider.handler_types.get(typ, None)
+ if model_handler_provider:
+ model_handler_provider.validate(model_handler['config'])
+ else:
+ raise NotImplementedError(
+ f'Unknown model handler type: {typ} '
+ f'at line {SafeLineLoader.get_line(spec)}.')
+
+ return spec
+
+
+class ModelHandlerProvider:
+ handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}
+
+ def __init__(
+      self, handler, preprocess: Callable = None, postprocess: Callable = None):
+ self._handler = handler
+ self._preprocess = self.parse_processing_transform(
+ preprocess, 'preprocess') or self.preprocess_fn
+ self._postprocess = self.parse_processing_transform(
+ postprocess, 'postprocess') or self.postprocess_fn
+
+ def get_output_schema(self):
+ return Any
+
+ @staticmethod
+ def parse_processing_transform(processing_transform, typ):
+ def _parse_config(callable=None, path=None, name=None):
+ if callable and (path or name):
+ raise ValueError(
+ f"Cannot specify 'callable' with 'path' and 'name' for {typ} "
+ f"function.")
+ if path and name:
+ return python_callable.PythonCallableWithSource.load_from_script(
+ FileSystems.open(path).read().decode(), name)
+ elif callable:
+ return python_callable.PythonCallableWithSource(callable)
+ else:
+ raise ValueError(
+ f"Must specify one of 'callable' or 'path' and 'name' for {typ} "
+ f"function.")
+
+ if processing_transform:
+ if isinstance(processing_transform, dict):
+ return _parse_config(**processing_transform)
+ else:
+ raise ValueError("Invalid model_handler specification.")
+
+ def underlying_handler(self):
+ return self._handler
+
+ def preprocess_fn(self, row):
+ raise ValueError(
+ 'Handler does not implement a default preprocess '
+ 'method. Please define a preprocessing method using the '
+ '\'preprocess\' tag.')
+
+ def create_preprocess_fn(self):
+ return lambda row: (row, self._preprocess(row))
+
+ @staticmethod
+ def postprocess_fn(x):
+ return x
+
+ def create_postprocess_fn(self):
+ return lambda result: (result[0], self._postprocess(result[1]))
+
+ @staticmethod
+ def validate(model_handler_spec):
+ raise NotImplementedError(type(ModelHandlerProvider))
+
+ @classmethod
+ def register_handler_type(cls, type_name):
+ def apply(constructor):
+ cls.handler_types[type_name] = constructor
+ return constructor
+
+ return apply
+
+ @classmethod
+ def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider":
+ typ = model_handler_spec['type']
+ config = model_handler_spec['config']
+ try:
+ result = cls.handler_types[typ](**config)
+ if not hasattr(result, 'to_json'):
+ result.to_json = lambda: model_handler_spec
+ return result
+ except Exception as exn:
+ raise ValueError(
+ f'Unable to instantiate model handler of type {typ}. {exn}')
+
+
[email protected]_handler_type('VertexAIModelHandlerJSON')
+class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
+ def __init__(
+ self,
+ endpoint_id: str,
+ endpoint_project: str,
+ endpoint_region: str,
+ experiment: Optional[str] = None,
+ network: Optional[str] = None,
+ private: bool = False,
+ min_batch_size: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
+ max_batch_duration_secs: Optional[int] = None,
+ env_vars=None,
+ preprocess: Callable = None,
+ postprocess: Callable = None):
+ """ModelHandler for Vertex AI.
+
+ For example: ::
+
+ - type: RunInference
+ config:
+ inference_tag: 'my_inference'
+ model_handler:
+ type: VertexAIModelHandlerJSON
+ config:
+ endpoint_id: 9876543210
Review Comment:
Yeah, I think this is it unfortunately.
##########
sdks/python/apache_beam/yaml/yaml_ml.py:
##########
@@ -33,11 +40,411 @@
tft = None # type: ignore
+class ModelHandlerProvider:
+ handler_types: Dict[str, Callable[..., "ModelHandlerProvider"]] = {}
+
+ def __init__(
+ self,
+ handler,
+ preprocess: Optional[Dict[str, str]] = None,
+ postprocess: Optional[Dict[str, str]] = None):
+ self._handler = handler
+ self._preprocess_fn = self.parse_processing_transform(
+ preprocess, 'preprocess') or self.default_preprocess_fn()
+ self._postprocess_fn = self.parse_processing_transform(
+ postprocess, 'postprocess') or self.default_postprocess_fn()
+
+ def inference_output_type(self):
+ return Any
+
+ @staticmethod
+ def parse_processing_transform(processing_transform, typ):
+ def _parse_config(callable=None, path=None, name=None):
+ if callable and (path or name):
+ raise ValueError(
+ f"Cannot specify 'callable' with 'path' and 'name' for {typ} "
+ f"function.")
+ if path and name:
+ return python_callable.PythonCallableWithSource.load_from_script(
+ FileSystems.open(path).read().decode(), name)
+ elif callable:
+ return python_callable.PythonCallableWithSource(callable)
+ else:
+ raise ValueError(
+ f"Must specify one of 'callable' or 'path' and 'name' for {typ} "
+ f"function.")
+
+ if processing_transform:
+ if isinstance(processing_transform, dict):
+ return _parse_config(**processing_transform)
+ else:
+ raise ValueError("Invalid model_handler specification.")
+
+ def underlying_handler(self):
+ return self._handler
+
+ @staticmethod
+ def default_preprocess_fn():
+ raise ValueError(
+ 'Handler does not implement a default preprocess '
+ 'method. Please define a preprocessing method using the '
+ '\'preprocess\' tag.')
+
+ def _preprocess_fn_internal(self):
+ return lambda row: (row, self._preprocess_fn(row))
+
+ @staticmethod
+ def default_postprocess_fn():
+ return lambda x: x
+
+ def _postprocess_fn_internal(self):
+ return lambda result: (result[0], self._postprocess_fn(result[1]))
+
+ @staticmethod
+ def validate(model_handler_spec):
+ raise NotImplementedError(type(ModelHandlerProvider))
+
+ @classmethod
+ def register_handler_type(cls, type_name):
+ def apply(constructor):
+ cls.handler_types[type_name] = constructor
+ return constructor
+
+ return apply
+
+ @classmethod
+ def create_handler(cls, model_handler_spec) -> "ModelHandlerProvider":
+ typ = model_handler_spec['type']
+ config = model_handler_spec['config']
+ try:
+ result = cls.handler_types[typ](**config)
+ if not hasattr(result, 'to_json'):
+ result.to_json = lambda: model_handler_spec
+ return result
+ except Exception as exn:
+ raise ValueError(
+ f'Unable to instantiate model handler of type {typ}. {exn}')
+
+
[email protected]_handler_type('VertexAIModelHandlerJSON')
+class VertexAIModelHandlerJSONProvider(ModelHandlerProvider):
+ def __init__(
+ self,
+ endpoint_id: str,
+ project: str,
+ location: str,
+ preprocess: Dict[str, str],
+ experiment: Optional[str] = None,
+ network: Optional[str] = None,
+ private: bool = False,
+ min_batch_size: Optional[int] = None,
+ max_batch_size: Optional[int] = None,
+ max_batch_duration_secs: Optional[int] = None,
+ env_vars: Optional[Dict[str, Any]] = None,
+ postprocess: Optional[Dict[str, str]] = None):
+ """
+ ModelHandler for Vertex AI.
+
+ For example: ::
+
+ - type: RunInference
+ config:
+ inference_tag: 'my_inference'
+ model_handler:
+ type: VertexAIModelHandlerJSON
+ config:
+ endpoint_id: 9876543210
+ project: my-project
+ location: us-east1
+ preprocess:
+ callable: 'lambda x: {"prompt": x.prompt, "max_tokens": 50}'
+
+ Args:
+ endpoint_id: the numerical ID of the Vertex AI endpoint to query.
+ project: the GCP project name where the endpoint is deployed.
+ location: the GCP location where the endpoint is deployed.
+ experiment: Experiment label to apply to the
+ queries. See
+      https://cloud.google.com/vertex-ai/docs/experiments/intro-vertex-ai-experiments
+ for more information.
+ network: The full name of the Compute Engine
+ network the endpoint is deployed on; used for private
+ endpoints. The network or subnetwork Dataflow pipeline
+ option must be set and match this network for pipeline
+ execution.
+ Ex: "projects/12345/global/networks/myVPC"
+ private: If the deployed Vertex AI endpoint is
+ private, set to true. Requires a network to be provided
+ as well.
+ min_batch_size: The minimum batch size to use when batching
+ inputs.
+ max_batch_size: The maximum batch size to use when batching
+ inputs.
+ max_batch_duration_secs: The maximum amount of time to buffer
+ a batch before emitting; used in streaming contexts.
+ env_vars: Environment variables.
+ preprocess: A python callable, defined either inline, or using a file,
+ that is invoked on the input row before sending to the model to be
+ loaded by this ModelHandler. This parameter is required by the
+ `VertexAIModelHandlerJSON` ModelHandler.
Review Comment:
It would be great to add more information on how this is used. E.g. "Vertex
AI expects requests formatted as json. The exact format of this depends on the
model, but here is an example..."
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]