This is an automated email from the ASF dual-hosted git repository. damccorm pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/master by this push: new 97cb452e4b9 Vllm model handler (#32410) 97cb452e4b9 is described below commit 97cb452e4b92603ea0a28a225c6ff1d60200840e Author: Danny McCormick <dannymccorm...@google.com> AuthorDate: Tue Sep 24 14:00:43 2024 -0400 Vllm model handler (#32410) * Vllm first pass [wip] * Example for integration tests wip * Still wip * Test changes * Dockerfile improvements * Remove bad change * Clean up test args * clean up invocation * string fix * string fix * clean up * lint * Get tests working with 5xx driver * cleanup * Fixes, everything is now working * Batching * lint * Feedback + CHANGES.md --- .github/trigger_files/beam_PostCommit_Python.json | 4 +- CHANGES.md | 15 +- build.gradle.kts | 1 + .../apache_beam/examples/inference/README.md | 80 ++++++ .../examples/inference/vllm_text_completion.py | 162 +++++++++++ .../ml/inference/test_resources/vllm.dockerfile | 47 ++++ .../apache_beam/ml/inference/vllm_inference.py | 312 +++++++++++++++++++++ sdks/python/setup.py | 1 + sdks/python/test-suites/dataflow/common.gradle | 39 +++ 9 files changed, 646 insertions(+), 15 deletions(-) diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json index d01a47e7265..30ee463ad4e 100644 --- a/.github/trigger_files/beam_PostCommit_Python.json +++ b/.github/trigger_files/beam_PostCommit_Python.json @@ -1,5 +1,5 @@ { - "comment": "modify this file in a trivial way to cause this test suite to run.", - "modification": 1 + "comment": "Modify this file in a trivial way to cause this test suite to run.", + "modification": 2 } diff --git a/CHANGES.md b/CHANGES.md index d58ceffeb41..c123a8e1a4d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -57,18 +57,13 @@ ## Highlights -* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)). -* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)). - -## I/Os - -* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) ## New Features / Improvements * Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)). * Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349)) -* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)). +* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528)) ## Breaking Changes @@ -77,11 +72,9 @@ as strings rather than silently coerced (and possibly truncated) to numeric values. To retain the old behavior, pass `dtype=True` (or any other value accepted by `pandas.read_json`). -* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)). ## Deprecations -* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)). * Python 3.8 is reaching EOL and support is being removed in Beam 2.61.0. The 2.60.0 release will warn users when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192)) @@ -92,10 +85,6 @@ when running on 3.8. 
([#31192](https://github.com/apache/beam/issues/31192)) ## Security Fixes * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)). -## Known Issues - -* ([#X](https://github.com/apache/beam/issues/X)). - # [2.59.0] - 2024-09-11 ## Highlights diff --git a/build.gradle.kts b/build.gradle.kts index d74cae3267e..38b58b6979e 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -542,6 +542,7 @@ tasks.register("python312PostCommit") { dependsOn(":sdks:python:test-suites:direct:py312:postCommitIT") dependsOn(":sdks:python:test-suites:direct:py312:hdfsIntegrationTest") dependsOn(":sdks:python:test-suites:portable:py312:postCommitPy312") + dependsOn(":sdks:python:test-suites:dataflow:py312:inferencePostCommitITPy312") } tasks.register("portablePythonPreCommit") { diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md index 3bb68440ed6..f9c5af43696 100644 --- a/sdks/python/apache_beam/examples/inference/README.md +++ b/sdks/python/apache_beam/examples/inference/README.md @@ -853,6 +853,7 @@ path/to/my/image2: dandelions (78) Each line represents a prediction of the flower type along with the confidence in that prediction. --- + ## Text classification with a Vertex AI LLM [`vertex_ai_llm_text_classification.py`](./vertex_ai_llm_text_classification.py) contains an implementation for a RunInference pipeline that performs image classification using a model hosted on Vertex AI (based on https://cloud.google.com/vertex-ai/docs/tutorials/image-recognition-custom). @@ -882,4 +883,83 @@ This writes the output to the output file with contents like: ``` Each line represents a tuple containing the example, a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult) object with the response from the model in the inference field, and the endpoint id representing the model id. +--- + +## Text completion with vLLM + +[`vllm_text_completion.py`](./vllm_text_completion.py) contains an implementation for a RunInference pipeline that performs text completion using a local [vLLM](https://docs.vllm.ai/en/latest/) server. + +The pipeline reads in a set of text prompts or past messages, uses RunInference to spin up a local inference server and perform inference, and then writes the predictions to a text file. + +### Model for text completion + +This transform works with any [LLM supported by vLLM](https://docs.vllm.ai/en/latest/models/supported_models.html). + +### Running `vllm_text_completion.py` + +To run the text completion pipeline locally using the Facebook opt-125m model, use the following command: +```sh +python -m apache_beam.examples.inference.vllm_text_completion \ + --model "facebook/opt-125m" \ + --output 'path/to/output/file.txt' \ + <... additional pipeline arguments to configure runner if not running in GPU environment ...> +``` + +You will need to either run this locally on a machine with a GPU accelerator or remotely on a runner that supports GPU acceleration.
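If you would rather embed the same transform in your own pipeline than run the example script, a minimal sketch is shown below (editor's illustration, not part of this commit: it assumes the `VLLMCompletionsModelHandler` added in this change is importable from `apache_beam.ml.inference.vllm_inference`, that the environment has `vllm` and `openai` installed, and that the prompts and output path are placeholders):

```python
# Editor's sketch: the same shape as vllm_text_completion.py, reduced to the
# essentials. The model handler below is the one introduced in this commit;
# the prompts and output path are illustrative placeholders.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler

prompts = [
    "Hello, my name is",
    "The capital of France is",
]

handler = VLLMCompletionsModelHandler(model_name="facebook/opt-125m")

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | "CreatePrompts" >> beam.Create(prompts)
      # RunInference starts a local vLLM OpenAI-compatible server and
      # queries it for each prompt.
      | "RunInference" >> RunInference(handler)
      # Each PredictionResult pairs the original prompt with the completion.
      | "Format" >> beam.Map(lambda r: f"{r.example}: {r.inference}")
      | "Write" >> beam.io.WriteToText(
          "path/to/output/file.txt", shard_name_template=""))
```

The example script does essentially this, plus argument parsing and an optional chat mode.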
+For example, you could run this on Dataflow with a GPU with the following command: + +```sh +python -m apache_beam.examples.inference.vllm_text_completion \ + --model "facebook/opt-125m" \ + --output 'gs://path/to/output/file.txt' \ + --runner dataflow \ + --project <gcp project> \ + --region us-central1 \ + --temp_location <temp gcs location> \ + --worker_harness_container_image "gcr.io/apache-beam-testing/beam-ml/vllm:latest" \ + --machine_type "n1-standard-4" \ + --dataflow_service_options "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx" \ + --staging_location <temp gcs location> +``` + +Make sure to enable the 5xx NVIDIA driver (`install-nvidia-driver:5xx`), since vLLM only works with 5xx drivers, not 4xx. + +This writes the output to the output file location with contents like: + +``` +'Hello, my name is', PredictionResult(example={'prompt': 'Hello, my name is'}, inference=Completion(id='cmpl-5f5113a317c949309582b1966511ffc4', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text=' Joel, my dad is Anton Harriman and my wife is Lydia. ', stop_reason=None)], created=1714064548, model='facebook/opt-125m', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=6, total_tokens=22))}) +``` +Each line represents a tuple containing the example and a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult) object with the response from the model in the inference field. + +You can also choose to run with chat examples. Doing this requires two steps: + +1) Upload a [chat_template](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#chat-template) to a file store that is accessible from your job's environment (e.g. a public Google Cloud Storage bucket). You can copy [this sample template](https://storage.googleapis.com/apache-beam-ml/additional_files/sample_chat_template.jinja) to get started. You can skip this step if using a model other than `facebook/opt-125m` and you know your model provides a chat template. +2) Add the `--chat true` and `--chat_template <gs://path/to/your/file>` parameters: + +```sh +python -m apache_beam.examples.inference.vllm_text_completion \ + --model "facebook/opt-125m" \ + --output 'gs://path/to/output/file.txt' \ + --chat true \ + --chat_template gs://path/to/your/file \ + <... additional pipeline arguments to configure runner if not running in GPU environment ...> +``` + +This will configure the pipeline to run against a sequence of previous messages instead of a single text completion prompt. +For example, it might run against: + +``` +[ + OpenAIChatMessage(role='user', content='What is an example of a type of penguin?'), + OpenAIChatMessage(role='assistant', content='An emperor penguin is a type of penguin.'), + OpenAIChatMessage(role='user', content='Tell me about them') +], +``` + +and produce the following result in your output file location: + +``` +An emperor penguin is an adorable creature that lives in Antarctica. +``` + --- \ No newline at end of file diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py new file mode 100644 index 00000000000..3cf7d04cb03 --- /dev/null +++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements.
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" A sample pipeline using the RunInference API to interface with an LLM using +vLLM. Takes in a set of prompts or lists of previous messages and produces +responses using a model of choice. + +Requires a GPU runtime with vllm, openai, and apache-beam installed to run +correctly. +""" + +import argparse +import logging +from typing import Iterable + +import apache_beam as beam +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage +from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler +from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.runners.runner import PipelineResult + +COMPLETION_EXAMPLES = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + "John cena is", +] + +CHAT_EXAMPLES = [ + [ + OpenAIChatMessage( + role='user', content='What is an example of a type of penguin?'), + OpenAIChatMessage( + role='assistant', content='Emperor penguin is a type of penguin.'), + OpenAIChatMessage(role='user', content='Tell me about them') + ], + [ + OpenAIChatMessage( + role='user', content='What colors are in the rainbow?'), + OpenAIChatMessage( + role='assistant', + content='Red, orange, yellow, green, blue, indigo, and violet.'), + OpenAIChatMessage(role='user', content='Do other colors ever appear?') + ], + [ + OpenAIChatMessage( + role='user', content='Who is the president of the United States?') + ], + [ + OpenAIChatMessage(role='user', content='What state is Fargo in?'), + OpenAIChatMessage(role='assistant', content='It is in North Dakota.'), + OpenAIChatMessage(role='user', content='How many people live there?'), + OpenAIChatMessage( + role='assistant', + content='Approximately 130,000 people live in Fargo, North Dakota.' 
+ ), + OpenAIChatMessage(role='user', content='What is Fargo known for?'), + ], + [ + OpenAIChatMessage( + role='user', content='How many fish are in the ocean?'), + ], +] + + +def parse_known_args(argv): + """Parses args for the workflow.""" + parser = argparse.ArgumentParser() + parser.add_argument( + '--model', + dest='model', + type=str, + required=False, + default='facebook/opt-125m', + help='LLM to use for task') + parser.add_argument( + '--output', + dest='output', + type=str, + required=True, + help='Path to save output predictions.') + parser.add_argument( + '--chat', + dest='chat', + type=bool, + required=False, + default=False, + help='Whether to use chat model handler and examples') + parser.add_argument( + '--chat_template', + dest='chat_template', + type=str, + required=False, + default=None, + help='Chat template to use for chat example.') + return parser.parse_known_args(argv) + + +class PostProcessor(beam.DoFn): + def process(self, element: PredictionResult) -> Iterable[str]: + yield str(element.example) + ": " + str(element.inference) + + +def run( + argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult: + """ + Args: + argv: Command line arguments defined for this example. + save_main_session: Used for internal testing. + test_pipeline: Used for internal testing. + """ + known_args, pipeline_args = parse_known_args(argv) + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + + model_handler = VLLMCompletionsModelHandler(model_name=known_args.model) + input_examples = COMPLETION_EXAMPLES + + if known_args.chat: + model_handler = VLLMChatModelHandler( + model_name=known_args.model, + chat_template_path=known_args.chat_template) + input_examples = CHAT_EXAMPLES + + pipeline = test_pipeline + if not test_pipeline: + pipeline = beam.Pipeline(options=pipeline_options) + + examples = pipeline | "Create examples" >> beam.Create(input_examples) + predictions = examples | "RunInference" >> RunInference(model_handler) + process_output = predictions | "Process Predictions" >> beam.ParDo( + PostProcessor()) + _ = process_output | "WriteOutput" >> beam.io.WriteToText( + known_args.output, shard_name_template='', append_trailing_newlines=True) + + result = pipeline.run() + result.wait_until_finish() + return result + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + run() diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile new file mode 100644 index 00000000000..5abbffdc5a2 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# Used for any vLLM integration test + +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 + +RUN apt update +RUN apt install software-properties-common -y +RUN add-apt-repository ppa:deadsnakes/ppa +RUN apt update + +ARG DEBIAN_FRONTEND=noninteractive + +RUN apt install python3.12 -y +RUN apt install python3.12-venv -y +RUN apt install python3.12-dev -y +RUN rm /usr/bin/python3 +RUN ln -s python3.12 /usr/bin/python3 +RUN python3 --version +RUN apt-get install -y curl +RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 && pip install --upgrade pip + +RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.58.1 +RUN pip install openai vllm + +RUN apt install libcairo2-dev pkg-config python3-dev -y +RUN pip install pycairo + +# Copy the Apache Beam worker dependencies from the Beam Python 3.8 SDK image. +COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam + +# Set the entrypoint to Apache Beam SDK worker launcher. +ENTRYPOINT [ "/opt/apache/beam/boot" ] diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py new file mode 100644 index 00000000000..28890083d93 --- /dev/null +++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py @@ -0,0 +1,312 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# pytype: skip-file + +import logging +import os +import subprocess +import threading +import time +import uuid +from dataclasses import dataclass +from typing import Any +from typing import Dict +from typing import Iterable +from typing import Optional +from typing import Sequence +from typing import Tuple + +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.utils import subprocess_server +from openai import OpenAI + +try: + import vllm # pylint: disable=unused-import + logging.info('vllm module successfully imported.') +except ModuleNotFoundError: + msg = 'vllm module was not found. This is ok as long as the specified ' \ + 'runner has vllm dependencies installed.' + logging.warning(msg) + +__all__ = [ + 'OpenAIChatMessage', + 'VLLMCompletionsModelHandler', + 'VLLMChatModelHandler', +] + + +@dataclass(frozen=True) +class OpenAIChatMessage(): + """" + Dataclass containing previous chat messages in conversation. + Role is the entity that sent the message (either 'user' or 'system'). + Content is the contents of the message. 
+ """ + role: str + content: str + + +def start_process(cmd) -> Tuple[subprocess.Popen, int]: + port, = subprocess_server.pick_port(None) + cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd] # pylint: disable=not-an-iterable + logging.info("Starting service with %s", str(cmd).replace("',", "'")) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + # Emit the output of this command as info level logging. + def log_stdout(): + line = process.stdout.readline() + while line: + # The log obtained from stdout is bytes, decode it into string. + # Remove newline via rstrip() to not print an empty line. + logging.info(line.decode(errors='backslashreplace').rstrip()) + line = process.stdout.readline() + + t = threading.Thread(target=log_stdout) + t.daemon = True + t.start() + return process, port + + +def getVLLMClient(port) -> OpenAI: + openai_api_key = "EMPTY" + openai_api_base = f"http://localhost:{port}/v1" + return OpenAI( + api_key=openai_api_key, + base_url=openai_api_base, + ) + + +class _VLLMModelServer(): + def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]): + self._model_name = model_name + self._vllm_server_kwargs = vllm_server_kwargs + self._server_started = False + self._server_process = None + self._server_port: int = -1 + + self.start_server() + + def start_server(self, retries=3): + if not self._server_started: + server_cmd = [ + 'python', + '-m', + 'vllm.entrypoints.openai.api_server', + '--model', + self._model_name, + '--port', + '{{PORT}}', + ] + for k, v in self._vllm_server_kwargs.items(): + server_cmd.append(f'--{k}') + server_cmd.append(v) + self._server_process, self._server_port = start_process(server_cmd) + + self.check_connectivity() + + def get_server_port(self) -> int: + if not self._server_started: + self.start_server() + return self._server_port + + def check_connectivity(self, retries=3): + client = getVLLMClient(self._server_port) + while self._server_process.poll() is None: + try: + models = client.models.list().data + logging.info('models: %s' % models) + if len(models) > 0: + self._server_started = True + return + except: # pylint: disable=bare-except + pass + # Sleep while bringing up the process + time.sleep(5) + + if retries == 0: + self._server_started = False + raise Exception( + "Failed to start vLLM server, polling process exited with code " + + "%s. Next time a request is tried, the server will be restarted" % + self._server_process.poll()) + else: + self.start_server(retries - 1) + + +class VLLMCompletionsModelHandler(ModelHandler[str, + PredictionResult, + _VLLMModelServer]): + def __init__( + self, + model_name: str, + vllm_server_kwargs: Optional[Dict[str, str]] = None): + """Implementation of the ModelHandler interface for vLLM using text as + input. + + Example Usage:: + + pcoll | RunInference(VLLMModelHandler(model_name='facebook/opt-125m')) + + Args: + model_name: The vLLM model. See + https://docs.vllm.ai/en/latest/models/supported_models.html for + supported models. + vllm_server_kwargs: Any additional kwargs to be passed into your vllm + server when it is being created. Will be invoked using + `python -m vllm.entrypoints.openai.api_serverv <beam provided args> + <vllm_server_kwargs>`. For example, you could pass + `{'echo': 'true'}` to prepend new messages with the previous message. 
+ For a list of possible kwargs, see + https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-completions-api + """ + self._model_name = model_name + self._vllm_server_kwargs: Dict[str, str] = vllm_server_kwargs or {} + self._env_vars = {} + + def load_model(self) -> _VLLMModelServer: + return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) + + def run_inference( + self, + batch: Sequence[str], + model: _VLLMModelServer, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """Runs inferences on a batch of text strings. + + Args: + batch: A sequence of examples as text strings. + model: A _VLLMModelServer containing info for connecting to the server. + inference_args: Any additional arguments for an inference. + + Returns: + An Iterable of type PredictionResult. + """ + client = getVLLMClient(model.get_server_port()) + inference_args = inference_args or {} + predictions = [] + # TODO(https://github.com/apache/beam/issues/32528): We should add support + # for taking in batches and doing a bunch of async calls. That will end up + # being more efficient when we can do in bundle batching. + for prompt in batch: + try: + completion = client.completions.create( + model=self._model_name, prompt=prompt, **inference_args) + predictions.append(completion) + except Exception as e: + model.check_connectivity() + raise e + + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def share_model_across_processes(self) -> bool: + return True + + +class VLLMChatModelHandler(ModelHandler[Sequence[OpenAIChatMessage], + PredictionResult, + _VLLMModelServer]): + def __init__( + self, + model_name: str, + chat_template_path: Optional[str] = None, + vllm_server_kwargs: Optional[Dict[str, str]] = None): + """ Implementation of the ModelHandler interface for vLLM using previous + messages as input. + + Example Usage:: + + pcoll | RunInference(VLLMModelHandler(model_name='facebook/opt-125m')) + + Args: + model_name: The vLLM model. See + https://docs.vllm.ai/en/latest/models/supported_models.html for + supported models. + chat_template_path: Path to a chat template. This file must be accessible + from your runner's execution environment, so it is recommended to use + a cloud based file storage system (e.g. Google Cloud Storage). + For info on chat templates, see: + https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#chat-template + vllm_server_kwargs: Any additional kwargs to be passed into your vllm + server when it is being created. Will be invoked using + `python -m vllm.entrypoints.openai.api_serverv <beam provided args> + <vllm_server_kwargs>`. For example, you could pass + `{'echo': 'true'}` to prepend new messages with the previous message. 
+ For a list of possible kwargs, see + https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api + """ + self._model_name = model_name + self._vllm_server_kwargs: Dict[str, str] = vllm_server_kwargs or {} + self._env_vars = {} + self._chat_template_path = chat_template_path + self._chat_file = f'template-{uuid.uuid4().hex}.jinja' + + def load_model(self) -> _VLLMModelServer: + chat_template_contents = '' + if self._chat_template_path is not None: + local_chat_template_path = os.path.join(os.getcwd(), self._chat_file) + if not os.path.exists(local_chat_template_path): + with FileSystems.open(self._chat_template_path) as fin: + chat_template_contents = fin.read().decode() + with open(local_chat_template_path, 'a') as f: + f.write(chat_template_contents) + self._vllm_server_kwargs['chat_template'] = local_chat_template_path + + return _VLLMModelServer(self._model_name, self._vllm_server_kwargs) + + def run_inference( + self, + batch: Sequence[Sequence[OpenAIChatMessage]], + model: _VLLMModelServer, + inference_args: Optional[Dict[str, Any]] = None + ) -> Iterable[PredictionResult]: + """Runs inferences on a batch of text strings. + + Args: + batch: A sequence of examples as OpenAI messages. + model: A _VLLMModelServer for connecting to the spun up server. + inference_args: Any additional arguments for an inference. + + Returns: + An Iterable of type PredictionResult. + """ + client = getVLLMClient(model.get_server_port()) + inference_args = inference_args or {} + predictions = [] + # TODO(https://github.com/apache/beam/issues/32528): We should add support + # for taking in batches and doing a bunch of async calls. That will end up + # being more efficient when we can do in bundle batching. + for messages in batch: + formatted = [] + for message in messages: + formatted.append({"role": message.role, "content": message.content}) + try: + completion = client.chat.completions.create( + model=self._model_name, messages=formatted, **inference_args) + predictions.append(completion) + except Exception as e: + model.check_connectivity() + raise e + + return [PredictionResult(x, y) for x, y in zip(batch, predictions)] + + def share_model_across_processes(self) -> bool: + return True diff --git a/sdks/python/setup.py b/sdks/python/setup.py index ddd7bfea52c..721cb4c1a8d 100644 --- a/sdks/python/setup.py +++ b/sdks/python/setup.py @@ -405,6 +405,7 @@ if __name__ == '__main__': # https://github.com/sphinx-doc/sphinx/issues/9727 'docutils==0.17.1', 'pandas<2.2.0', + 'openai' ], 'test': [ 'docstring-parser>=0.15,<1.0', diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle index e5d301ecbe1..6bca904c1a6 100644 --- a/sdks/python/test-suites/dataflow/common.gradle +++ b/sdks/python/test-suites/dataflow/common.gradle @@ -424,6 +424,39 @@ def tensorRTTests = tasks.create("tensorRTtests") { } } +def vllmTests = tasks.create("vllmTests") { + dependsOn 'installGcpTest' + dependsOn ':sdks:python:sdist' + doLast { + def testOpts = basicPytestOpts + def argMap = [ + "runner": "DataflowRunner", + "machine_type":"n1-standard-4", + // TODO(https://github.com/apache/beam/issues/22651): Build docker image for VLLM tests during Run time. + // This would also enable to use wheel "--sdk_location" as other tasks, and eliminate distTarBall dependency + // declaration for this project. 
+ // Right now, this is built from https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile + "sdk_container_image": "us.gcr.io/apache-beam-testing/python-postcommit-it/vllm:latest", + "sdk_location": files(configurations.distTarBall.files).singleFile, + "project": "apache-beam-testing", + "region": "us-central1", + "model": "facebook/opt-125m", + "output": "gs://apache-beam-ml/outputs/vllm_predictions.txt", + "disk_size_gb": 75 + ] + def cmdArgs = mapToArgString(argMap) + // Exec one version with and one version without the chat option + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'" + } + exec { + executable 'sh' + args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat true --chat_template 'gs://apache-beam-ml/additional_files/sample_chat_template.jinja' --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'" + } + } +} + // Vertex AI RunInference IT tests task vertexAIInferenceTest { dependsOn 'initializeForDataflowJob' @@ -521,6 +554,12 @@ project.tasks.register("inferencePostCommitIT") { ] } +project.tasks.register("inferencePostCommitITPy312") { + dependsOn = [ + 'vllmTests', + ] +} + // Create cross-language tasks for running tests against Java expansion service(s) def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing'
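For completeness, here is a rough sketch of how the chat-oriented handler introduced in `vllm_inference.py` above is meant to be wired up (editor's illustration, not part of the commit: the sample conversation, chat template path, and output path are placeholders, and the execution environment needs `vllm` and `openai` installed):

```python
# Editor's sketch of the chat path: VLLMChatModelHandler and OpenAIChatMessage
# are the classes added in apache_beam/ml/inference/vllm_inference.py above;
# the paths and the sample conversation are placeholders.
import apache_beam as beam
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage
from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler

conversations = [
    [
        OpenAIChatMessage(
            role="user", content="What is an example of a type of penguin?"),
        OpenAIChatMessage(
            role="assistant", content="An emperor penguin is a type of penguin."),
        OpenAIChatMessage(role="user", content="Tell me about them"),
    ],
]

# chat_template_path is optional; facebook/opt-125m needs one because the
# model does not ship its own chat template.
handler = VLLMChatModelHandler(
    model_name="facebook/opt-125m",
    chat_template_path="gs://path/to/your/chat_template.jinja")

with beam.Pipeline() as pipeline:
  _ = (
      pipeline
      | "CreateConversations" >> beam.Create(conversations)
      | "RunInference" >> RunInference(handler)
      # inference holds the OpenAI-style chat completion from the local server.
      | "Format" >> beam.Map(lambda r: str(r.inference))
      | "Write" >> beam.io.WriteToText(
          "gs://path/to/output/file.txt", shard_name_template=""))
```

As with the completions handler, each `PredictionResult` pairs the input conversation with the chat completion returned by the local vLLM server.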