liferoad commented on code in PR #27399:
URL: https://github.com/apache/beam/pull/27399#discussion_r1273960409


##########
sdks/python/apache_beam/ml/inference/huggingface_inference.py:
##########
@@ -483,6 +499,129 @@ def share_model_across_processes(self) -> bool:
   def get_metrics_namespace(self) -> str:
     """
     Returns:
-        A namespace for metrics collected by the RunInference transform.
+       A namespace for metrics collected by the RunInference transform.
+    """
+    return 'BeamML_HuggingFaceModelHandler_Tensor'
+
+
+def _convert_to_result(
+    batch: Iterable,
+    predictions: Union[Iterable, Dict[Any, Iterable]],
+    model_id: Optional[str] = None,
+) -> Iterable[PredictionResult]:
+  return [
+      PredictionResult(x, y, model_id) for x, y in zip(batch, [predictions])
+  ]
+
+
+def _default_pipeline_inference_fn(
+    batch, model, inference_args) -> Iterable[PredictionResult]:
+  predictions = model(batch, **inference_args)
+  return predictions
+
+
+class HuggingFacePipelineModelHandler(ModelHandler[str,
+                                                   PredictionResult,
+                                                   Pipeline]):
+  def __init__(
+      self,
+      task: str = "",
+      model=None,
+      *,
+      inference_fn: PipelineInferenceFn = _default_pipeline_inference_fn,
+      load_model_args: Optional[Dict[str, Any]] = None,
+      inference_args: Optional[Dict[str, Any]] = None,
+      min_batch_size: Optional[int] = None,
+      max_batch_size: Optional[int] = None,
+      large_model: bool = False,
+      **kwargs):
+    """
+    Implementation of the ModelHandler interface for Hugging Face Pipelines.
+
+    **Note:** To specify which device to use (CPU/GPU),
+    use the load_model_args with key-value as you would do in the usual
+    Hugging Face pipeline. Ex: load_model_args={'device':0}
+
+    Example Usage model::
+      pcoll | RunInference(HuggingFacePipelineModelHandler(
+        task="fill-mask"))
+
+    Args:
+      task (str): task supported by HuggingFace Pipelines.

Review Comment:
   It would be better to provide a link so users can check which tasks are supported: 
https://huggingface.co/transformers/v4.12.5/_modules/transformers/pipelines.html#pipeline
 has the current list. We could also accept both an enum and a str for `task`.
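
   A minimal sketch of what accepting both an enum and a str could look like. The names `PipelineTask` and `normalize_task` are hypothetical and not part of this PR, and the enum members are only an illustrative subset of the Hugging Face task names. Unknown strings are passed through unchanged, on the assumption that the Hugging Face task list changes between versions and should not be hard-coded exhaustively.

```python
from enum import Enum
from typing import Union


class PipelineTask(str, Enum):
  """Hypothetical enum: an illustrative subset of Hugging Face pipeline tasks."""
  FILL_MASK = 'fill-mask'
  QUESTION_ANSWERING = 'question-answering'
  SUMMARIZATION = 'summarization'
  TEXT_CLASSIFICATION = 'text-classification'
  TEXT_GENERATION = 'text-generation'


def normalize_task(task: Union[str, PipelineTask]) -> str:
  """Returns the plain task string for either an enum member or a str.

  Hypothetical helper. Strings are not validated against the enum, so task
  names that the enum does not list are still accepted.
  """
  # Check the enum first: PipelineTask subclasses str, so its members would
  # also satisfy the isinstance(task, str) check below.
  if isinstance(task, PipelineTask):
    return task.value
  if isinstance(task, str):
    return task
  raise TypeError(
      'task must be a str or PipelineTask, got %s' % type(task).__name__)
```

   With this, the handler's `__init__` could call `normalize_task(task)` once and keep passing a plain string to the Hugging Face pipeline, so existing callers that pass `task="fill-mask"` would not need to change.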



