chamikaramj commented on code in PR #24656: URL: https://github.com/apache/beam/pull/24656#discussion_r1050036835
########## sdks/python/apache_beam/examples/inference/multi_language/expansion_service/run_inference_expansion.py: ##########
@@ -0,0 +1,209 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# pytype: skip-file
+
+import argparse
+import logging
+import signal
+import sys
+import typing
+
+import grpc
+
+import apache_beam as beam
+from apache_beam.coders import RowCoder
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerKeyedTensor
+from apache_beam.pipeline import PipelineOptions
+from apache_beam.portability.api import beam_artifact_api_pb2_grpc
+from apache_beam.portability.api import beam_expansion_api_pb2_grpc
+from apache_beam.portability.api import external_transforms_pb2
+from apache_beam.runners.portability import artifact_service
+from apache_beam.runners.portability import expansion_service
+from apache_beam.transforms import fully_qualified_named_transform
+from apache_beam.transforms import ptransform
+from apache_beam.transforms.environments import PyPIArtifactRegistry
+from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
+from apache_beam.utils import thread_pool_executor
+from transformers import BertConfig
+from transformers import BertForMaskedLM
+from transformers import BertTokenizer
+
+# This script provides an expansion service for a RunInference transform
+# with pre- and post-processing.
+# The model used is the BERT base uncased masked language model.
+_LOGGER = logging.getLogger(__name__)
+
+# A transform that runs inference on a BERT model.
+TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert"
+
+
[email protected]_urn(TEST_RUN_BERT_URN, None)
+class RunInferenceTransform(ptransform.PTransform):
+  class PytorchModelHandlerKeyedTensorWrapper(PytorchModelHandlerKeyedTensor):
+    """Wrapper around PytorchModelHandlerKeyedTensor that limits the batch
+    size to 1.
+
+    The tokenized strings generated from BertTokenizer may have different
+    lengths, which doesn't work with torch.stack() in the current
+    RunInference implementation, since stack() requires tensors to be the
+    same size.
+    Restricting max_batch_size to 1 means there is only one example per
+    `batch` in the run_inference() call.
+ """ + def batch_elements_kwargs(self): + return {'max_batch_size': 1} + + class Preprocess(beam.DoFn): + def __init__(self, tokenizer): + # self._model_name = model_name + logging.info('Starting Preprocess') + # self._tokenizer = BertTokenizer.from_pretrained(self._model_name) + self._tokenizer = tokenizer + logging.info('Tokenizer loaded') + + def process(self, text: str): + import torch + if len(text.strip()) > 0: + logging.info('Preprocessing Line: %s', text) + text_list = text.split() + masked_text = ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]]) + tokens = self._tokenizer(masked_text, return_tensors='pt') + tokens = {key: torch.squeeze(val) for key, val in tokens.items()} + return [(text, tokens)] + + class Postprocess(beam.DoFn): + def __init__(self, bert_tokenizer): + logging.info('Starting Postprocess') + self.bert_tokenizer = bert_tokenizer + + def process(self, element: typing.Tuple[str, PredictionResult]) \ + -> typing.Iterable[str]: + text, prediction_result = element + inputs = prediction_result.example + logits = prediction_result.inference['logits'] + mask_token_index = ( + inputs['input_ids'] == self.bert_tokenizer.mask_token_id).nonzero( + as_tuple=True)[0] + predicted_token_id = logits[mask_token_index].argmax(axis=-1) + decoded_word = self.bert_tokenizer.decode(predicted_token_id) + text = text.replace('.', '') + yield text + '\n Predicted word: ' + decoded_word.upper() + + def __init__(self, model): + self._model = model + # can also save the model config and tokenizer in gcs and load in + self._model_config = BertConfig.from_pretrained(self._model) + self._tokenizer = BertTokenizer.from_pretrained(self._model) + self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper( + state_dict_path=( + "gs://apache-beam-x-lang-testing/input/" + "bert-model/bert-base-uncased.pth"), + model_class=BertForMaskedLM, + model_params={'config': self._model_config}, + device='cuda:0') + + def expand(self, pcoll): + return ( + pcoll + | 'Preprocess' >> beam.ParDo(self.Preprocess(self._tokenizer)) + | 'Inference' >> RunInference(KeyedModelHandler(self._model_handler)) + | 'Postprocess' >> beam.ParDo(self.Postprocess( + self._tokenizer)).with_input_types(typing.Iterable[str])) + + def to_runner_api_parameter(self, unused_context): + return TEST_RUN_BERT_URN, ImplicitSchemaPayloadBuilder( + {'model': self._model}).payload() + + @staticmethod + def from_runner_api_parameter(unused_ptransform, payload, unused_context): + return RunInferenceTransform(parse_string_payload(payload)['model']) + + [email protected]_urn('payload', bytes) +class PayloadTransform(ptransform.PTransform): + def __init__(self, payload): + self._payload = payload + + def expand(self, pcoll): + return pcoll | beam.Map(lambda x, s: x + s, self._payload) + + def to_runner_api_parameter(self, unused_context): + return b'payload', self._payload.encode('ascii') + + @staticmethod + def from_runner_api_parameter(unused_ptransform, payload, unused_context): + return PayloadTransform(payload.decode('ascii')) + + +def parse_string_payload(input_byte): + payload = external_transforms_pb2.ExternalConfigurationPayload() + payload.ParseFromString(input_byte) + + return RowCoder(payload.schema).decode(payload.payload)._asdict() + + +server = None + + +def cleanup(unused_signum, unused_frame): + _LOGGER.info('Shutting down expansion service.') + server.stop(None) + + +def main(unused_argv): + PyPIArtifactRegistry.register_artifact('beautifulsoup4', '>=4.9,<5.0 \n') Review Comment: You can register artifact directly using the 
API when using the PythonExternalTransform API. https://github.com/apache/beam/blob/62b7247a28e1f9f4fa23cfbe202f939a5634caa8/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java#L288 In addition to PyPi packages, local packages are also supported starting with Beam 2.44.0. ########## sdks/python/apache_beam/examples/inference/multi_language/last_word_prediction/src/main/java/org/MultiLangRunInference.java: ########## @@ -0,0 +1,93 @@ +package org; +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.IOException; +import org.apache.beam.model.pipeline.v1.ExternalTransforms; +import org.apache.beam.runners.core.construction.External; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.coders.RowCoder; +import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.options.Description; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.options.Validation.Required; +import org.apache.beam.sdk.schemas.Schema; +import org.apache.beam.sdk.schemas.Schema.Field; +import org.apache.beam.sdk.schemas.Schema.FieldType; +import org.apache.beam.sdk.schemas.SchemaTranslation; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.util.ByteStringOutputStream; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.Row; +import org.apache.beam.sdk.values.PDone; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class MultiLangRunInference { + public interface MultiLangueageOptions extends PipelineOptions { + + @Description("Path to an input file that contains labels and pixels to feed into the model") + @Required + String getInputFile(); + + void setInputFile(String value); + + @Description("Path to an input file that contains labels and pixels to feed into the model") + @Required + String getOutputFile(); + + void setOutputFile(String value); + } + + private static byte[] toStringPayloadBytes(String model) { + Row configRow = Row.withSchema(Schema.of(Field.of("model", FieldType.STRING))) + .withFieldValue("model", model) + .build(); + + ByteStringOutputStream outputStream = new ByteStringOutputStream(); + + try { + RowCoder.of(configRow.getSchema()).encode(configRow, outputStream); + } catch (IOException e) { + throw new RuntimeException(e); + } + ExternalTransforms.ExternalConfigurationPayload payload = ExternalTransforms.ExternalConfigurationPayload + .newBuilder() + .setSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), false)) + .setPayload(outputStream.toByteString()) + .build(); + return payload.toByteArray(); + } + + public static void main(String[] args) 
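For example, a minimal sketch of the Java side (the fully qualified transform name and the `input` PCollection are illustrative assumptions, not names from this PR):

```java
import java.util.Arrays;

import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.values.PCollection;

// withExtraPackages() stages the listed dependencies into the Python
// environment, so no PyPIArtifactRegistry call is needed in the
// expansion service itself. The package spec mirrors the
// register_artifact() call above.
PCollection<String> predictions =
    input.apply(
        "Run Inference",
        PythonExternalTransform
            .<PCollection<String>, PCollection<String>>from(
                "run_inference_expansion.RunInferenceTransform")
            .withKwarg("model", "bert-base-uncased")
            .withExtraPackages(Arrays.asList("beautifulsoup4>=4.9,<5.0")));
```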
########## sdks/python/apache_beam/examples/inference/multi_language/last_word_prediction/src/main/java/org/MultiLangRunInference.java: ##########
@@ -0,0 +1,93 @@
+package org;
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import org.apache.beam.model.pipeline.v1.ExternalTransforms;
+import org.apache.beam.runners.core.construction.External;
+import org.apache.beam.sdk.Pipeline;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.io.TextIO;
+import org.apache.beam.sdk.options.Description;
+import org.apache.beam.sdk.options.PipelineOptions;
+import org.apache.beam.sdk.options.PipelineOptionsFactory;
+import org.apache.beam.sdk.options.Validation.Required;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
+import org.apache.beam.sdk.schemas.SchemaTranslation;
+import org.apache.beam.sdk.transforms.DoFn;
+import org.apache.beam.sdk.transforms.ParDo;
+import org.apache.beam.sdk.util.ByteStringOutputStream;
+import org.apache.beam.sdk.values.PCollection;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.sdk.values.PDone;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class MultiLangRunInference {
+  public interface MultiLanguageOptions extends PipelineOptions {
+
+    @Description("Path to an input file that contains the text to feed into the model")
+    @Required
+    String getInputFile();
+
+    void setInputFile(String value);
+
+    @Description("Path to the output file where predictions are written")
+    @Required
+    String getOutputFile();
+
+    void setOutputFile(String value);
+  }
+
+  private static byte[] toStringPayloadBytes(String model) {
+    Row configRow = Row.withSchema(Schema.of(Field.of("model", FieldType.STRING)))
+        .withFieldValue("model", model)
+        .build();
+
+    ByteStringOutputStream outputStream = new ByteStringOutputStream();
+
+    try {
+      RowCoder.of(configRow.getSchema()).encode(configRow, outputStream);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    ExternalTransforms.ExternalConfigurationPayload payload = ExternalTransforms.ExternalConfigurationPayload
+        .newBuilder()
+        .setSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), false))
+        .setPayload(outputStream.toByteString())
+        .build();
+    return payload.toByteArray();
+  }
+
+  public static void main(String[] args) {
+
+    final String TEST_RUN_BERT_URN = "beam:transforms:xlang:test:run_bert";
+    MultiLanguageOptions options = PipelineOptionsFactory.fromArgs(args).withValidation()
+        .as(MultiLanguageOptions.class);
+
+    Pipeline pipeline = Pipeline.create(options);
+    PCollection<String> predictions = pipeline.apply("Read Input", TextIO.read().from(options.getInputFile()))
+        .apply("Run Inference", External.of(TEST_RUN_BERT_URN, toStringPayloadBytes("bert-base-uncased"), "localhost:12345"));

Review Comment: Instead of using the low-level External.of API for x-lang, please update the example to use the high-level PythonExternalTransform API. This is the recommended way to use Python transforms from Java. https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/PythonExternalTransform.java It is documented here: https://beam.apache.org/documentation/programming-guide/#1312-creating-cross-language-python-transforms
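A rough sketch of the suggested change, reusing `options` from the pipeline above (the fully qualified module path is an illustrative assumption; the URN and payload plumbing are no longer needed):

```java
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.values.PCollection;

// Same pipeline as above, but through the high-level API. The expansion
// service address is optional; if it is omitted, Beam brings up a default
// expansion service automatically.
Pipeline pipeline = Pipeline.create(options);
PCollection<String> predictions =
    pipeline
        .apply("Read Input", TextIO.read().from(options.getInputFile()))
        .apply(
            "Run Inference",
            PythonExternalTransform
                .<PCollection<String>, PCollection<String>>from(
                    "run_inference_expansion.RunInferenceTransform",
                    "localhost:12345")
                .withKwarg("model", "bert-base-uncased"));
```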
########## sdks/python/apache_beam/examples/inference/multi_language/expansion_service/run_inference_expansion.py: ##########
+def main(unused_argv):

Review Comment: I think this main should be removed, and we should add instructions for just starting up the default expansion service, following the instructions here: https://beam.apache.org/documentation/sdks/java-multi-language-pipelines/#advanced-start-an-expansion-service We can just add the Python code to a local package and include it via the PythonExternalTransform.withExtraPackages API (starting with Beam 2.44.0).
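A sketch of that approach, assuming the pre/post-processing code is packaged as a local Python distribution (`run_inference_pkg` and the tarball path are illustrative):

```java
import java.util.Arrays;

import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.values.PCollection;

// No expansion service address is given, so Beam starts the default
// expansion service itself; the local package supplies the transform code.
PCollection<String> predictions =
    lines.apply(
        "Run Inference",
        PythonExternalTransform
            .<PCollection<String>, PCollection<String>>from(
                "run_inference_pkg.RunInferenceTransform")
            .withKwarg("model", "bert-base-uncased")
            .withExtraPackages(Arrays.asList("./run_inference_pkg.tar.gz")));
```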
########## sdks/python/apache_beam/examples/inference/multi_language/README.md: ##########
@@ -0,0 +1,49 @@
+<!--
+    Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+      http://www.apache.org/licenses/LICENSE-2.0
+
+    Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+-->
+## Setting up the expansion service
+In order to start the Python expansion service, run the following command:
+
+```
+python -m expansion_service.run_inference_expansion \
+  --port=<port to host expansion service> \

Review Comment: You can select any available port (but users of the Beam-provided RunInference wrapper should not need this, as I mentioned in the other comment).

########## sdks/python/apache_beam/examples/inference/multi_language/README.md: ##########
+## Setting up the expansion service
+In order to start the Python expansion service, run the following command:

Review Comment: If you are using the Beam-provided RunInference Java wrapper [1] with a released version of Beam, an expansion service is started up automatically. In other cases, a custom expansion service is currently needed. We should at least clarify this in the documentation.

[1] https://github.com/apache/beam/blob/master/sdks/java/extensions/python/src/main/java/org/apache/beam/sdk/extensions/python/transforms/RunInference.java
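For reference, a rough sketch of that wrapper's usage; the model handler and schema below are illustrative assumptions borrowed from the sklearn examples, not this PR's BERT setup:

```java
import org.apache.beam.sdk.extensions.python.transforms.RunInference;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.Schema.Field;
import org.apache.beam.sdk.schemas.Schema.FieldType;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;

// The wrapper expands the transform through an automatically started
// expansion service; no --port flag or manually started service is needed.
Schema schema =
    Schema.of(
        Field.of("example", FieldType.STRING),
        Field.of("inference", FieldType.STRING));
PCollection<Row> inferences =
    lines.apply(
        RunInference.of(
                "apache_beam.ml.inference.sklearn_inference.SklearnModelHandlerNumpy",
                schema)
            .withKwarg("model_uri", "gs://my-bucket/model.pickle"));
```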
########## sdks/python/apache_beam/examples/inference/multi_language/README.md: ##########
+python -m expansion_service.run_inference_expansion \
+  --port=<port to host expansion service> \
+  --env_config=<python container>
+```
+If you use a custom Python container, make sure it is publicly accessible for the workers to pull.

Review Comment: I would just defer to the following documentation: https://cloud.google.com/dataflow/docs/guides/using-custom-containers For authenticating to GCP, the following instructions should suffice: https://cloud.google.com/dataflow/docs/concepts/security-and-permissions For other systems, the container might have to include the secret, I believe.

########## sdks/python/apache_beam/examples/inference/multi_language/last_word_prediction/src/main/java/org/MultiLangRunInference.java: ##########
+    PCollection<String> predictions = pipeline.apply("Read Input", TextIO.read().from(options.getInputFile()))
+        .apply("Run Inference", External.of(TEST_RUN_BERT_URN, toStringPayloadBytes("bert-base-uncased"), "localhost:12345"));
Review Comment: This should also simplify your code quite a bit.

########## sdks/python/apache_beam/examples/inference/multi_language/last_word_prediction/src/main/java/org/MultiLangRunInference.java: ##########
+  private static byte[] toStringPayloadBytes(String model) {
+    Row configRow = Row.withSchema(Schema.of(Field.of("model", FieldType.STRING)))
+        .withFieldValue("model", model)
+        .build();
+
+    ByteStringOutputStream outputStream = new ByteStringOutputStream();
+
+    try {
+      RowCoder.of(configRow.getSchema()).encode(configRow, outputStream);
+    } catch (IOException e) {
+      throw new RuntimeException(e);
+    }
+    ExternalTransforms.ExternalConfigurationPayload payload = ExternalTransforms.ExternalConfigurationPayload
+        .newBuilder()
+        .setSchema(SchemaTranslation.schemaToProto(configRow.getSchema(), false))
+        .setPayload(outputStream.toByteString())
+        .build();
+    return payload.toByteArray();
+  }
Review Comment: Yeah, you should be able to just pass strings, and they should be encoded (via UTF-8) automatically.
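In other words, with PythonExternalTransform the helper above becomes unnecessary; a minimal sketch (the module path is an illustrative assumption):

```java
import org.apache.beam.sdk.extensions.python.PythonExternalTransform;
import org.apache.beam.sdk.values.PCollection;

// A plain Java String kwarg is encoded as UTF-8 automatically, so the
// whole toStringPayloadBytes() Row/payload construction can be dropped.
PythonExternalTransform
    .<PCollection<String>, PCollection<String>>from(
        "run_inference_expansion.RunInferenceTransform")
    .withKwarg("model", "bert-base-uncased");
```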
+ """ + def batch_elements_kwargs(self): + return {'max_batch_size': 1} + + class Preprocess(beam.DoFn): + def __init__(self, tokenizer): + # self._model_name = model_name + logging.info('Starting Preprocess') + # self._tokenizer = BertTokenizer.from_pretrained(self._model_name) + self._tokenizer = tokenizer + logging.info('Tokenizer loaded') + + def process(self, text: str): + import torch + if len(text.strip()) > 0: + logging.info('Preprocessing Line: %s', text) + text_list = text.split() + masked_text = ' '.join(text_list[:-2] + ['[MASK]', text_list[-1]]) + tokens = self._tokenizer(masked_text, return_tensors='pt') + tokens = {key: torch.squeeze(val) for key, val in tokens.items()} + return [(text, tokens)] + + class Postprocess(beam.DoFn): + def __init__(self, bert_tokenizer): + logging.info('Starting Postprocess') + self.bert_tokenizer = bert_tokenizer + + def process(self, element: typing.Tuple[str, PredictionResult]) \ + -> typing.Iterable[str]: + text, prediction_result = element + inputs = prediction_result.example + logits = prediction_result.inference['logits'] + mask_token_index = ( + inputs['input_ids'] == self.bert_tokenizer.mask_token_id).nonzero( + as_tuple=True)[0] + predicted_token_id = logits[mask_token_index].argmax(axis=-1) + decoded_word = self.bert_tokenizer.decode(predicted_token_id) + text = text.replace('.', '') + yield text + '\n Predicted word: ' + decoded_word.upper() + + def __init__(self, model): + self._model = model + # can also save the model config and tokenizer in gcs and load in + self._model_config = BertConfig.from_pretrained(self._model) + self._tokenizer = BertTokenizer.from_pretrained(self._model) + self._model_handler = self.PytorchModelHandlerKeyedTensorWrapper( + state_dict_path=( + "gs://apache-beam-x-lang-testing/input/" + "bert-model/bert-base-uncased.pth"), + model_class=BertForMaskedLM, + model_params={'config': self._model_config}, + device='cuda:0') + + def expand(self, pcoll): + return ( + pcoll + | 'Preprocess' >> beam.ParDo(self.Preprocess(self._tokenizer)) + | 'Inference' >> RunInference(KeyedModelHandler(self._model_handler)) + | 'Postprocess' >> beam.ParDo(self.Postprocess( + self._tokenizer)).with_input_types(typing.Iterable[str])) + + def to_runner_api_parameter(self, unused_context): + return TEST_RUN_BERT_URN, ImplicitSchemaPayloadBuilder( + {'model': self._model}).payload() + + @staticmethod + def from_runner_api_parameter(unused_ptransform, payload, unused_context): + return RunInferenceTransform(parse_string_payload(payload)['model']) + + [email protected]_urn('payload', bytes) Review Comment: All this should be updated/simplified by using the PythonExternalTransform API (pls see the other comment). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
