chamikaramj commented on code in PR #24656: URL: https://github.com/apache/beam/pull/24656#discussion_r1052469265
########## sdks/python/apache_beam/examples/inference/multi_language/expansion_service/run_inference_expansion.py: ########## @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pytype: skip-file + +import argparse +import logging +import signal +import sys +import typing + +import grpc + +import apache_beam as beam +from apache_beam.coders import RowCoder +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor +from apache_beam.pipeline import PipelineOptions +from apache_beam.portability.api import beam_artifact_api_pb2_grpc +from apache_beam.portability.api import beam_expansion_api_pb2_grpc +from apache_beam.portability.api import external_transforms_pb2 +from apache_beam.runners.portability import artifact_service +from apache_beam.runners.portability import expansion_service +from apache_beam.transforms import fully_qualified_named_transform +from apache_beam.transforms import ptransform +from apache_beam.transforms.environments import PyPIArtifactRegistry +from apache_beam.transforms.external import 
ImplicitSchemaPayloadBuilder +from apache_beam.utils import thread_pool_executor +from transformers import BertConfig +from transformers import BertForMaskedLM +from transformers import BertTokenizer + +# This script provides an expansion service for a run inference transform +# with pre and post processing. +# The model used is a BertLM, base uncased model. +_LOGGER = logging.getLogger(__name__) + +# A transform that runs inference on a Bertmodel. +TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert" + + +@ptransform.PTransform.register_urn(TEST_RUN_BERT_URN, None) +class RunInferenceTransform(ptransform.PTransform): + class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor): + """Wrapper to PytorchModelHandler to limit batch size to 1. + The tokenized strings generated from BertTokenizer may have different + lengths, which doesn't work with torch.stack() in current RunInference + implementation since stack() requires tensors to be the same size. + Restricting max_batch_size to 1 means there is only 1 example per + `batch` in the run_inference() call. 
+ """ + def batch_elements_kwargs(self): + return {'max_batch_size': 1} + + class Preprocess(beam.DoFn): + def __init__(self, tokenizer): + # self._model_name = model_name + logging.info('Starting Preprocess') + # self._tokenizer = BertTokenizer.from_pretrained(self._model_name) + self._tokenizer = tokenizer + logging.info('Tokenizer loaded') + + def process(self, text: str): + import torch + if len(text.strip()) > 0: + logging.info('Preprocessing Line: %s', text) + text_list = text.split() + masked_text = ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]]) + tokens = self._tokenizer(masked_text, return_tensors='pt') + tokens = {key: torch.squeeze(val) for key, val in tokens.items()} + return [(text, tokens)] + + class Postprocess(beam.DoFn): + def __init__(self, bert_tokenizer): + logging.info('Starting Postprocess') + self.bert_tokenizer = bert_tokenizer + + def process(self, element: typing.Tuple[str, PredictionResult]) \ + -> typing.Iterable[str]: + text, prediction_result = element + inputs = prediction_result.example + logits = prediction_result.inference['logits'] + mask_token_index = ( + inputs['input_ids'] == self.bert_tokenizer.mask_token_id).nonzero( + as_tuple=True)[0] + predicted_token_id = logits[mask_token_index].argmax(axis=-1) + decoded_word = self.bert_tokenizer.decode(predicted_token_id) + text = text.replace('.', '') + yield text + '\n Predicted word: ' + decoded_word.upper() + + def __init__(self, model): + self._model = model + # can also save the model config and tokenizer in gcs and load in + self._model_config = BertConfig.from_pretrained(self._model) + self._tokenizer = BertTokenizer.from_pretrained(self._model) + self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper( + state_dict_path=( + "gs://apache-beam-x-lang-testing/input/" + "bert-model/bert-base-uncased.pth"), + model_class=BertForMaskedLM, + model_params={'config': self._model_config}, + device='cuda:0') + + def expand(self, pcoll): + return ( + pcoll + | 'Preprocess' 
>> beam.ParDo(self.Preprocess(self._tokenizer)) + | 'Inference' >> RunInference(KeyedModelHandler(self._model_handler)) + | 'Postprocess' >> beam.ParDo(self.Postprocess( + self._tokenizer)).with_input_types(typing.Iterable[str])) + + def to_runner_api_parameter(self, unused_context): + return TEST_RUN_BERT_URN, ImplicitSchemaPayloadBuilder( + {'model': self._model}).payload() + + @staticmethod + def from_runner_api_parameter(unused_ptransform, payload, unused_context): + return RunInferenceTransform(parse_string_payload(payload)['model']) + + +@ptransform.PTransform.register_urn('payload', bytes) +class PayloadTransform(ptransform.PTransform): + def __init__(self, payload): + self._payload = payload + + def expand(self, pcoll): + return pcoll | beam.Map(lambda x, s: x + s, self._payload) + + def to_runner_api_parameter(self, unused_context): + return b'payload', self._payload.encode('ascii') + + @staticmethod + def from_runner_api_parameter(unused_ptransform, payload, unused_context): + return PayloadTransform(payload.decode('ascii')) + + +def parse_string_payload(input_byte): + payload = external_transforms_pb2.ExternalConfigurationPayload() + payload.ParseFromString(input_byte) + + return RowCoder(payload.schema).decode(payload.payload)._asdict() + + +server = None + + +def cleanup(unused_signum, unused_frame): + _LOGGER.info('Shutting down expansion service.') + server.stop(None) + + +def main(unused_argv): + PyPIArtifactRegistry.register_artifact('beautifulsoup4', '>=4.9,<5.0 \n') Review Comment: It's not available yet, but the release is ongoing, so probably soon. With the current release, the "withExtraPackages" API does not support specifying local packages (only PyPI packages), so you'll have to set up a custom expansion service with your code and provide a custom container for execution. I think it's good to go with that option now and provide a simplified mechanism later for users on 2.43.0 or later. WDYT? -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
