shub-kris commented on code in PR #23497:
URL: https://github.com/apache/beam/pull/23497#discussion_r991006900


##########
sdks/python/apache_beam/examples/inference/anomaly_detection/anomaly_detection_pipeline/pipeline/transformations.py:
##########
@@ -0,0 +1,192 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This file contains the transformations and utility functions for
+the anomaly_detection pipeline."""
+import json
+
+import numpy as np
+
+import apache_beam as beam
+import config as cfg
+import hdbscan
+import torch
+import yagmail
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.sklearn_inference import SklearnModelHandlerNumpy
+from apache_beam.ml.inference.sklearn_inference import _validate_inference_args
+from transformers import AutoTokenizer
+from transformers import DistilBertModel
+
+# [START tokenization]
+Tokenizer = AutoTokenizer.from_pretrained(cfg.TOKENIZER_NAME)
+
+
+def tokenize_sentence(input_dict):
+  """
+    It takes a dictionary with a text and an id, tokenizes the text, and
+    returns a tuple of the text and id and the tokenized text
+
+    Args:
+      input_dict: a dictionary with the text and id of the sentence
+
+    Returns:
+      A tuple of the text and id, and a dictionary of the tokens.
+    """
+  text, uid = input_dict["text"], input_dict["id"]
+  tokens = Tokenizer([text], padding=True, truncation=True, return_tensors="pt")
+  tokens = {key: torch.squeeze(val) for key, val in tokens.items()}
+  return (text, uid), tokens
+
+
+# [END tokenization]
+
+
+# [START DistilBertModelWrapper]
+class ModelWrapper(DistilBertModel):
+  """Wrapper to DistilBertModel to get embeddings when calling
+    forward function."""
+  def forward(self, **kwargs):
+    output = super().forward(**kwargs)
+    sentence_embedding = (
+        self.mean_pooling(output,
+                          kwargs["attention_mask"]).detach().cpu().numpy())
+    return sentence_embedding
+
+  # Mean Pooling - Take attention mask into account for correct averaging
+  def mean_pooling(self, model_output, attention_mask):
+    """
+        The function calculates the mean of token embeddings
+
+        Args:
+          model_output: The output of the model.
+          attention_mask: This is a tensor that contains 1s for all input tokens and
+          0s for all padding tokens.
+
+        Returns:
+          The mean of the token embeddings.
+        """
+    token_embeddings = model_output[
+        0]  # First element of model_output contains all token embeddings
+    input_mask_expanded = (
+        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float())
+    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
+        input_mask_expanded.sum(1), min=1e-9)
+
+
+# [END DistilBertModelWrapper]
+
+
+# [START CustomSklearnModelHandlerNumpy]
+class CustomSklearnModelHandlerNumpy(SklearnModelHandlerNumpy):
+  # Can be removed once: https://github.com/apache/beam/issues/21863 is fixed
+  def batch_elements_kwargs(self):
+    """Limit batch size to 1 for inference"""
+    return {"max_batch_size": 1}
+
+  def run_inference(self, batch, model, inference_args=None):

Review Comment:
   Sure



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to