This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 97cb452e4b9 Vllm model handler (#32410)
97cb452e4b9 is described below

commit 97cb452e4b92603ea0a28a225c6ff1d60200840e
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Tue Sep 24 14:00:43 2024 -0400

    Vllm model handler (#32410)
    
    * Vllm first pass [wip]
    
    * Example for integration tests wip
    
    * Still wip
    
    * Test changes
    
    * Dockerfile improvements
    
    * Remove bad change
    
    * Clean up test args
    
    * clean up invocation
    
    * string fix
    
    * string fix
    
    * clean up
    
    * lint
    
    * Get tests working with 5xx driver
    
    * cleanup
    
    * Fixes, everything is now working
    
    * Batching
    
    * lint
    
    * Feedback + CHANGES.md
---
 .github/trigger_files/beam_PostCommit_Python.json  |   4 +-
 CHANGES.md                                         |  15 +-
 build.gradle.kts                                   |   1 +
 .../apache_beam/examples/inference/README.md       |  80 ++++++
 .../examples/inference/vllm_text_completion.py     | 162 +++++++++++
 .../ml/inference/test_resources/vllm.dockerfile    |  47 ++++
 .../apache_beam/ml/inference/vllm_inference.py     | 312 +++++++++++++++++++++
 sdks/python/setup.py                               |   1 +
 sdks/python/test-suites/dataflow/common.gradle     |  39 +++
 9 files changed, 646 insertions(+), 15 deletions(-)

diff --git a/.github/trigger_files/beam_PostCommit_Python.json b/.github/trigger_files/beam_PostCommit_Python.json
index d01a47e7265..30ee463ad4e 100644
--- a/.github/trigger_files/beam_PostCommit_Python.json
+++ b/.github/trigger_files/beam_PostCommit_Python.json
@@ -1,5 +1,5 @@
 {
-  "comment": "modify this file in a trivial way to cause this test suite to 
run.",
-  "modification": 1
+  "comment": "Modify this file in a trivial way to cause this test suite to 
run.",
+  "modification": 2
 }
 
diff --git a/CHANGES.md b/CHANGES.md
index d58ceffeb41..c123a8e1a4d 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -57,18 +57,13 @@
 
 ## Highlights
 
-* New highly anticipated feature X added to Python SDK ([#X](https://github.com/apache/beam/issues/X)).
-* New highly anticipated feature Y added to Java SDK ([#Y](https://github.com/apache/beam/issues/Y)).
-
-## I/Os
-
-* Support for X source added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528))
 
 ## New Features / Improvements
 
 * Dataflow worker can install packages from Google Artifact Registry Python repositories (Python) ([#32123](https://github.com/apache/beam/issues/32123)).
 * Added support for Zstd codec in SerializableAvroCodecFactory (Java) ([#32349](https://github.com/apache/beam/issues/32349))
-* X feature added (Java/Python) ([#X](https://github.com/apache/beam/issues/X)).
+* Added support for using vLLM in the RunInference transform (Python) ([#32528](https://github.com/apache/beam/issues/32528))
 
 ## Breaking Changes
 
@@ -77,11 +72,9 @@
   as strings rather than silently coerced (and possibly truncated) to numeric
   values.  To retain the old behavior, pass `dtype=True` (or any other value
   accepted by `pandas.read_json`).
-* X behavior was changed ([#X](https://github.com/apache/beam/issues/X)).
 
 ## Deprecations
 
-* X behavior is deprecated and will be removed in X versions ([#X](https://github.com/apache/beam/issues/X)).
 * Python 3.8 is reaching EOL and support is being removed in Beam 2.61.0. The 2.60.0 release will warn users
 when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192))
 
@@ -92,10 +85,6 @@ when running on 3.8. ([#31192](https://github.com/apache/beam/issues/31192))
 ## Security Fixes
 * Fixed (CVE-YYYY-NNNN)[https://www.cve.org/CVERecord?id=CVE-YYYY-NNNN] (Java/Python/Go) ([#X](https://github.com/apache/beam/issues/X)).
 
-## Known Issues
-
-* ([#X](https://github.com/apache/beam/issues/X)).
-
 # [2.59.0] - 2024-09-11
 
 ## Highlights
diff --git a/build.gradle.kts b/build.gradle.kts
index d74cae3267e..38b58b6979e 100644
--- a/build.gradle.kts
+++ b/build.gradle.kts
@@ -542,6 +542,7 @@ tasks.register("python312PostCommit") {
   dependsOn(":sdks:python:test-suites:direct:py312:postCommitIT")
   dependsOn(":sdks:python:test-suites:direct:py312:hdfsIntegrationTest")
   dependsOn(":sdks:python:test-suites:portable:py312:postCommitPy312")
+  dependsOn(":sdks:python:test-suites:dataflow:py312:inferencePostCommitITPy312")
 }
 
 tasks.register("portablePythonPreCommit") {
diff --git a/sdks/python/apache_beam/examples/inference/README.md b/sdks/python/apache_beam/examples/inference/README.md
index 3bb68440ed6..f9c5af43696 100644
--- a/sdks/python/apache_beam/examples/inference/README.md
+++ b/sdks/python/apache_beam/examples/inference/README.md
@@ -853,6 +853,7 @@ path/to/my/image2: dandelions (78)
 Each line represents a prediction of the flower type along with the confidence in that prediction.
 
 ---
+
## Text classification with a Vertex AI LLM
 
 
[`vertex_ai_llm_text_classification.py`](./vertex_ai_llm_text_classification.py) contains an implementation for a RunInference pipeline that performs image classification using a model hosted on Vertex AI (based on https://cloud.google.com/vertex-ai/docs/tutorials/image-recognition-custom).
@@ -882,4 +883,83 @@ This writes the output to the output file with contents like:
 ```
 Each line represents a tuple containing the example, a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult) object with the response from the model in the inference field, and the endpoint id representing the model id.
+---
+
+## Text completion with vLLM
+
+[`vllm_text_completion.py`](./vllm_text_completion.py) contains an implementation for a RunInference pipeline that performs text completion using a local [vLLM](https://docs.vllm.ai/en/latest/) server.
+
+The pipeline reads in a set of text prompts or past messages, uses RunInference to spin up a local inference server and perform inference, and then writes the predictions to a text file.
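+
+At its core, the pipeline is a short RunInference composition. A minimal sketch (prompts are illustrative; the handler is the one defined in `vllm_inference.py` in this change):
+
+```py
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler
+
+with beam.Pipeline() as pipeline:
+  _ = (
+      pipeline
+      | beam.Create(['Hello, my name is', 'The capital of France is'])
+      | RunInference(VLLMCompletionsModelHandler(model_name='facebook/opt-125m'))
+      | beam.Map(print))
+```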
+
+### Model for text completion
+
+To use this transform, you can use any [LLM supported by vLLM](https://docs.vllm.ai/en/latest/models/supported_models.html).
+
+### Running `vllm_text_completion.py`
+
+To run the text completion pipeline locally using the Facebook OPT-125M model, use the following command:
+```sh
+python -m apache_beam.examples.inference.vllm_text_completion \
+  --model "facebook/opt-125m" \
+  --output 'path/to/output/file.txt' \
+  <... additional pipeline arguments to configure the runner if not running in a GPU environment ...>
+```
+
+You will need to run this either locally on a machine with a GPU accelerator or remotely on a runner that supports GPU acceleration.
+For example, you could run it on Dataflow with a GPU using the following command:
+
+```sh
+python -m apache_beam.examples.inference.vllm_text_completion \
+  --model "facebook/opt-125m" \
+  --output 'gs://path/to/output/file.txt' \
+  --runner dataflow \
+  --project <gcp project> \
+  --region us-central1 \
+  --temp_location <temp gcs location> \
+  --worker_harness_container_image "gcr.io/apache-beam-testing/beam-ml/vllm:latest" \
+  --machine_type "n1-standard-4" \
+  --dataflow_service_options "worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx" \
+  --staging_location <temp gcs location>
+```
+
+Make sure to request the 5xx NVIDIA driver (via `install-nvidia-driver:5xx`), since vLLM only works with 5xx drivers, not 4xx.
+
+This writes the output to the specified output file with contents like:
+
+```
+'Hello, my name is', PredictionResult(example={'prompt': 'Hello, my name is'}, inference=Completion(id='cmpl-5f5113a317c949309582b1966511ffc4', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text=' Joel, my dad is Anton Harriman and my wife is Lydia. ', stop_reason=None)], created=1714064548, model='facebook/opt-125m', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=6, total_tokens=22))})
+```
+Each line represents a tuple containing the example and a [PredictionResult](https://beam.apache.org/releases/pydoc/2.40.0/apache_beam.ml.inference.base.html#apache_beam.ml.inference.base.PredictionResult) object with the response from the model in the inference field.
+
+You can also choose to run with chat examples. Doing this requires two steps:
+
+1) Upload a [chat_template](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#chat-template) to a filestore which is accessible from your job's environment (e.g. a public Google Cloud Storage bucket). You can copy [this sample template](https://storage.googleapis.com/apache-beam-ml/additional_files/sample_chat_template.jinja) to get started. You can skip this step if using a model other than `facebook/opt-125m` and you know your model provides a chat template.
+2) Add the `--chat true` and `--chat_template <gs://path/to/your/file>` parameters:
+
+```sh
+python -m apache_beam.examples.inference.vllm_text_completion \
+  --model "facebook/opt-125m" \
+  --output 'gs://path/to/output/file.txt' \
+  --chat true \
+  --chat_template gs://path/to/your/file \
+  <... additional pipeline arguments to configure the runner if not running in a GPU environment ...>
+```
+
+This will configure the pipeline to run against a sequence of previous messages instead of a single text completion prompt.
+For example, it might run against:
+
+```
+[
+    OpenAIChatMessage(role='user', content='What is an example of a type of penguin?'),
+    OpenAIChatMessage(role='system', content='An emperor penguin is a type of penguin.'),
+    OpenAIChatMessage(role='user', content='Tell me about them')
+],
+```
+
+and produce the following result in your output file location:
+
+```
+An emperor penguin is an adorable creature that lives in Antarctica.
+```
+
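+The same chat pipeline expressed in code is a small variation on the completion example. A minimal sketch, assuming the chat template GCS path from step 1 above:
+
+```py
+import apache_beam as beam
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage
+from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler
+
+chat_handler = VLLMChatModelHandler(
+    model_name='facebook/opt-125m',
+    chat_template_path='gs://path/to/your/file')
+
+with beam.Pipeline() as pipeline:
+  _ = (
+      pipeline
+      | beam.Create([[OpenAIChatMessage(role='user', content='Tell me about penguins')]])
+      | RunInference(chat_handler)
+      | beam.Map(print))
+```
+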
 ---
\ No newline at end of file
diff --git a/sdks/python/apache_beam/examples/inference/vllm_text_completion.py b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py
new file mode 100644
index 00000000000..3cf7d04cb03
--- /dev/null
+++ b/sdks/python/apache_beam/examples/inference/vllm_text_completion.py
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+""" A sample pipeline using the RunInference API to interface with an LLM using
+vLLM. Takes in a set of prompts or lists of previous messages and produces
+responses using a model of choice.
+
+Requires a GPU runtime with vllm, openai, and apache-beam installed to run
+correctly.
+"""
+
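+# A minimal environment sketch (package names taken from the docstring above and
+# the accompanying vllm.dockerfile; pin versions to match your runtime):
+#
+#   pip install apache-beam[gcp] openai vllm
+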
+import argparse
+import logging
+from typing import Iterable
+
+import apache_beam as beam
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.ml.inference.vllm_inference import OpenAIChatMessage
+from apache_beam.ml.inference.vllm_inference import VLLMChatModelHandler
+from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.runners.runner import PipelineResult
+
+COMPLETION_EXAMPLES = [
+    "Hello, my name is",
+    "The president of the United States is",
+    "The capital of France is",
+    "The future of AI is",
+    "John cena is",
+]
+
+CHAT_EXAMPLES = [
+    [
+        OpenAIChatMessage(
+            role='user', content='What is an example of a type of penguin?'),
+        OpenAIChatMessage(
+            role='assistant', content='Emperor penguin is a type of penguin.'),
+        OpenAIChatMessage(role='user', content='Tell me about them')
+    ],
+    [
+        OpenAIChatMessage(
+            role='user', content='What colors are in the rainbow?'),
+        OpenAIChatMessage(
+            role='assistant',
+            content='Red, orange, yellow, green, blue, indigo, and violet.'),
+        OpenAIChatMessage(role='user', content='Do other colors ever appear?')
+    ],
+    [
+        OpenAIChatMessage(
+            role='user', content='Who is the president of the United States?')
+    ],
+    [
+        OpenAIChatMessage(role='user', content='What state is Fargo in?'),
+        OpenAIChatMessage(role='assistant', content='It is in North Dakota.'),
+        OpenAIChatMessage(role='user', content='How many people live there?'),
+        OpenAIChatMessage(
+            role='assistant',
+            content='Approximately 130,000 people live in Fargo, North Dakota.'
+        ),
+        OpenAIChatMessage(role='user', content='What is Fargo known for?'),
+    ],
+    [
+        OpenAIChatMessage(
+            role='user', content='How many fish are in the ocean?'),
+    ],
+]
+
+
+def parse_known_args(argv):
+  """Parses args for the workflow."""
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      '--model',
+      dest='model',
+      type=str,
+      required=False,
+      default='facebook/opt-125m',
+      help='LLM to use for task')
+  parser.add_argument(
+      '--output',
+      dest='output',
+      type=str,
+      required=True,
+      help='Path to save output predictions.')
+  parser.add_argument(
+      '--chat',
+      dest='chat',
+      type=bool,
+      required=False,
+      default=False,
+      help='Whether to use chat model handler and examples')
+  parser.add_argument(
+      '--chat_template',
+      dest='chat_template',
+      type=str,
+      required=False,
+      default=None,
+      help='Chat template to use for chat example.')
+  return parser.parse_known_args(argv)
+
+
+class PostProcessor(beam.DoFn):
+  def process(self, element: PredictionResult) -> Iterable[str]:
+    yield str(element.example) + ": " + str(element.inference)
+
+
+def run(
+    argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult:
+  """
+  Args:
+    argv: Command line arguments defined for this example.
+    save_main_session: Used for internal testing.
+    test_pipeline: Used for internal testing.
+  """
+  known_args, pipeline_args = parse_known_args(argv)
+  pipeline_options = PipelineOptions(pipeline_args)
+  pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
+
+  model_handler = VLLMCompletionsModelHandler(model_name=known_args.model)
+  input_examples = COMPLETION_EXAMPLES
+
+  if known_args.chat:
+    model_handler = VLLMChatModelHandler(
+        model_name=known_args.model,
+        chat_template_path=known_args.chat_template)
+    input_examples = CHAT_EXAMPLES
+
+  pipeline = test_pipeline
+  if not test_pipeline:
+    pipeline = beam.Pipeline(options=pipeline_options)
+
+  examples = pipeline | "Create examples" >> beam.Create(input_examples)
+  predictions = examples | "RunInference" >> RunInference(model_handler)
+  process_output = predictions | "Process Predictions" >> beam.ParDo(
+      PostProcessor())
+  _ = process_output | "WriteOutput" >> beam.io.WriteToText(
+      known_args.output, shard_name_template='', append_trailing_newlines=True)
+
+  result = pipeline.run()
+  result.wait_until_finish()
+  return result
+
+
+if __name__ == '__main__':
+  logging.getLogger().setLevel(logging.INFO)
+  run()
diff --git a/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile
new file mode 100644
index 00000000000..5abbffdc5a2
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile
@@ -0,0 +1,47 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Used for any vLLM integration test
+
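+# A hypothetical build-and-push sketch for this image (tag taken from
+# sdks/python/test-suites/dataflow/common.gradle; adjust for your own registry):
+#
+#   docker build -t us.gcr.io/apache-beam-testing/python-postcommit-it/vllm:latest -f vllm.dockerfile .
+#   docker push us.gcr.io/apache-beam-testing/python-postcommit-it/vllm:latest
+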
+FROM nvidia/cuda:12.4.1-devel-ubuntu22.04
+
+RUN apt update
+RUN apt install software-properties-common -y
+RUN add-apt-repository ppa:deadsnakes/ppa
+RUN apt update
+
+ARG DEBIAN_FRONTEND=noninteractive
+
+RUN apt install python3.12 -y
+RUN apt install python3.12-venv -y
+RUN apt install python3.12-dev -y
+RUN rm /usr/bin/python3
+RUN ln -s python3.12 /usr/bin/python3
+RUN python3 --version
+RUN apt-get install -y curl
+RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.12 && pip install --upgrade pip
+
+RUN pip install --no-cache-dir -vvv apache-beam[gcp]==2.58.1
+RUN pip install openai vllm
+
+RUN apt install libcairo2-dev pkg-config python3-dev -y
+RUN pip install pycairo
+
+# Copy the Apache Beam worker dependencies from the Beam Python 3.12 SDK image.
+COPY --from=apache/beam_python3.12_sdk:2.58.1 /opt/apache/beam /opt/apache/beam
+
+# Set the entrypoint to Apache Beam SDK worker launcher.
+ENTRYPOINT [ "/opt/apache/beam/boot" ]
diff --git a/sdks/python/apache_beam/ml/inference/vllm_inference.py b/sdks/python/apache_beam/ml/inference/vllm_inference.py
new file mode 100644
index 00000000000..28890083d93
--- /dev/null
+++ b/sdks/python/apache_beam/ml/inference/vllm_inference.py
@@ -0,0 +1,312 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# pytype: skip-file
+
+import logging
+import os
+import subprocess
+import threading
+import time
+import uuid
+from dataclasses import dataclass
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.utils import subprocess_server
+from openai import OpenAI
+
+try:
+  import vllm  # pylint: disable=unused-import
+  logging.info('vllm module successfully imported.')
+except ModuleNotFoundError:
+  msg = 'vllm module was not found. This is ok as long as the specified ' \
+    'runner has vllm dependencies installed.'
+  logging.warning(msg)
+
+__all__ = [
+    'OpenAIChatMessage',
+    'VLLMCompletionsModelHandler',
+    'VLLMChatModelHandler',
+]
+
+
+@dataclass(frozen=True)
+class OpenAIChatMessage():
+  """"
+  Dataclass containing previous chat messages in conversation.
+  Role is the entity that sent the message (either 'user' or 'system').
+  Content is the contents of the message.
+  """
+  role: str
+  content: str
+
+
+def start_process(cmd) -> Tuple[subprocess.Popen, int]:
+  port, = subprocess_server.pick_port(None)
+  cmd = [arg.replace('{{PORT}}', str(port)) for arg in cmd]  # pylint: disable=not-an-iterable
+  logging.info("Starting service with %s", str(cmd).replace("',", "'"))
+  process = subprocess.Popen(
+      cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+
+  # Emit the output of this command as info level logging.
+  def log_stdout():
+    line = process.stdout.readline()
+    while line:
+      # The log obtained from stdout is bytes, decode it into string.
+      # Remove newline via rstrip() to not print an empty line.
+      logging.info(line.decode(errors='backslashreplace').rstrip())
+      line = process.stdout.readline()
+
+  t = threading.Thread(target=log_stdout)
+  t.daemon = True
+  t.start()
+  return process, port
+
+
+def getVLLMClient(port) -> OpenAI:
+  openai_api_key = "EMPTY"
+  openai_api_base = f"http://localhost:{port}/v1"
+  return OpenAI(
+      api_key=openai_api_key,
+      base_url=openai_api_base,
+  )
+
+
+class _VLLMModelServer():
+  def __init__(self, model_name: str, vllm_server_kwargs: Dict[str, str]):
+    self._model_name = model_name
+    self._vllm_server_kwargs = vllm_server_kwargs
+    self._server_started = False
+    self._server_process = None
+    self._server_port: int = -1
+
+    self.start_server()
+
+  def start_server(self, retries=3):
+    if not self._server_started:
+      server_cmd = [
+          'python',
+          '-m',
+          'vllm.entrypoints.openai.api_server',
+          '--model',
+          self._model_name,
+          '--port',
+          '{{PORT}}',
+      ]
+      for k, v in self._vllm_server_kwargs.items():
+        server_cmd.append(f'--{k}')
+        server_cmd.append(v)
+      self._server_process, self._server_port = start_process(server_cmd)
+
+    self.check_connectivity(retries)
+
+  def get_server_port(self) -> int:
+    if not self._server_started:
+      self.start_server()
+    return self._server_port
+
+  def check_connectivity(self, retries=3):
+    client = getVLLMClient(self._server_port)
+    while self._server_process.poll() is None:
+      try:
+        models = client.models.list().data
+        logging.info('models: %s' % models)
+        if len(models) > 0:
+          self._server_started = True
+          return
+      except:  # pylint: disable=bare-except
+        pass
+      # Sleep while bringing up the process
+      time.sleep(5)
+
+    if retries == 0:
+      self._server_started = False
+      raise Exception(
+          "Failed to start vLLM server, polling process exited with code " +
+          "%s.  Next time a request is tried, the server will be restarted" %
+          self._server_process.poll())
+    else:
+      self.start_server(retries - 1)
+
+
+class VLLMCompletionsModelHandler(ModelHandler[str,
+                                               PredictionResult,
+                                               _VLLMModelServer]):
+  def __init__(
+      self,
+      model_name: str,
+      vllm_server_kwargs: Optional[Dict[str, str]] = None):
+    """Implementation of the ModelHandler interface for vLLM using text as
+    input.
+
+    Example Usage::
+
+      pcoll | RunInference(VLLMCompletionsModelHandler(model_name='facebook/opt-125m'))
+
+    Args:
+      model_name: The vLLM model. See
+        https://docs.vllm.ai/en/latest/models/supported_models.html for
+        supported models.
+      vllm_server_kwargs: Any additional kwargs to be passed into your vllm
+        server when it is being created. Will be invoked using
+        `python -m vllm.entrypoints.openai.api_server <beam provided args>
+        <vllm_server_kwargs>`. For example, you could pass
+        `{'echo': 'true'}` to prepend new messages with the previous message.
+        For a list of possible kwargs, see
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-completions-api
+    """
+    self._model_name = model_name
+    self._vllm_server_kwargs: Dict[str, str] = vllm_server_kwargs or {}
+    self._env_vars = {}
+
+  def load_model(self) -> _VLLMModelServer:
+    return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
+
+  def run_inference(
+      self,
+      batch: Sequence[str],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """Runs inferences on a batch of text strings.
+
+    Args:
+      batch: A sequence of examples as text strings.
+      model: A _VLLMModelServer containing info for connecting to the server.
+      inference_args: Any additional arguments for an inference.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    client = getVLLMClient(model.get_server_port())
+    inference_args = inference_args or {}
+    predictions = []
+    # TODO(https://github.com/apache/beam/issues/32528): We should add support
+    # for taking in batches and doing a bunch of async calls. That will end up
+    # being more efficient when we can do in bundle batching.
+    for prompt in batch:
+      try:
+        completion = client.completions.create(
+            model=self._model_name, prompt=prompt, **inference_args)
+        predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def share_model_across_processes(self) -> bool:
+    return True
+
+
+class VLLMChatModelHandler(ModelHandler[Sequence[OpenAIChatMessage],
+                                        PredictionResult,
+                                        _VLLMModelServer]):
+  def __init__(
+      self,
+      model_name: str,
+      chat_template_path: Optional[str] = None,
+      vllm_server_kwargs: Optional[Dict[str, str]] = None):
+    """ Implementation of the ModelHandler interface for vLLM using previous
+    messages as input.
+
+    Example Usage::
+
+      pcoll | RunInference(VLLMChatModelHandler(model_name='facebook/opt-125m'))
+
+    Args:
+      model_name: The vLLM model. See
+        https://docs.vllm.ai/en/latest/models/supported_models.html for
+        supported models.
+      chat_template_path: Path to a chat template. This file must be accessible
+        from your runner's execution environment, so it is recommended to use
+        a cloud based file storage system (e.g. Google Cloud Storage).
+        For info on chat templates, see:
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#chat-template
+      vllm_server_kwargs: Any additional kwargs to be passed into your vllm
+        server when it is being created. Will be invoked using
+        `python -m vllm.entrypoints.openai.api_server <beam provided args>
+        <vllm_server_kwargs>`. For example, you could pass
+        `{'echo': 'true'}` to prepend new messages with the previous message.
+        For a list of possible kwargs, see
+        https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#extra-parameters-for-chat-api
+    """
+    self._model_name = model_name
+    self._vllm_server_kwargs: Dict[str, str] = vllm_server_kwargs or {}
+    self._env_vars = {}
+    self._chat_template_path = chat_template_path
+    self._chat_file = f'template-{uuid.uuid4().hex}.jinja'
+
+  def load_model(self) -> _VLLMModelServer:
+    chat_template_contents = ''
+    if self._chat_template_path is not None:
+      local_chat_template_path = os.path.join(os.getcwd(), self._chat_file)
+      if not os.path.exists(local_chat_template_path):
+        with FileSystems.open(self._chat_template_path) as fin:
+          chat_template_contents = fin.read().decode()
+        with open(local_chat_template_path, 'a') as f:
+          f.write(chat_template_contents)
+      self._vllm_server_kwargs['chat_template'] = local_chat_template_path
+
+    return _VLLMModelServer(self._model_name, self._vllm_server_kwargs)
+
+  def run_inference(
+      self,
+      batch: Sequence[Sequence[OpenAIChatMessage]],
+      model: _VLLMModelServer,
+      inference_args: Optional[Dict[str, Any]] = None
+  ) -> Iterable[PredictionResult]:
+    """Runs inferences on a batch of text strings.
+
+    Args:
+      batch: A sequence of examples as OpenAI messages.
+      model: A _VLLMModelServer for connecting to the spun up server.
+      inference_args: Any additional arguments for an inference.
+
+    Returns:
+      An Iterable of type PredictionResult.
+    """
+    client = getVLLMClient(model.get_server_port())
+    inference_args = inference_args or {}
+    predictions = []
+    # TODO(https://github.com/apache/beam/issues/32528): We should add support
+    # for taking in batches and doing a bunch of async calls. That will end up
+    # being more efficient when we can do in bundle batching.
+    for messages in batch:
+      formatted = []
+      for message in messages:
+        formatted.append({"role": message.role, "content": message.content})
+      try:
+        completion = client.chat.completions.create(
+            model=self._model_name, messages=formatted, **inference_args)
+        predictions.append(completion)
+      except Exception as e:
+        model.check_connectivity()
+        raise e
+
+    return [PredictionResult(x, y) for x, y in zip(batch, predictions)]
+
+  def share_model_across_processes(self) -> bool:
+    return True
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index ddd7bfea52c..721cb4c1a8d 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -405,6 +405,7 @@ if __name__ == '__main__':
               # https://github.com/sphinx-doc/sphinx/issues/9727
               'docutils==0.17.1',
               'pandas<2.2.0',
+              'openai'
           ],
           'test': [
               'docstring-parser>=0.15,<1.0',
diff --git a/sdks/python/test-suites/dataflow/common.gradle b/sdks/python/test-suites/dataflow/common.gradle
index e5d301ecbe1..6bca904c1a6 100644
--- a/sdks/python/test-suites/dataflow/common.gradle
+++ b/sdks/python/test-suites/dataflow/common.gradle
@@ -424,6 +424,39 @@ def tensorRTTests = tasks.create("tensorRTtests") {
  }
 }
 
+def vllmTests = tasks.create("vllmTests") {
+  dependsOn 'installGcpTest'
+  dependsOn ':sdks:python:sdist'
+ doLast {
+  def testOpts = basicPytestOpts
+  def argMap = [
+    "runner": "DataflowRunner",
+    "machine_type":"n1-standard-4",
+    // TODO(https://github.com/apache/beam/issues/22651): Build docker image for VLLM tests during run time.
+    // This would also enable using the wheel "--sdk_location" as other tasks do, and eliminate the
+    // distTarBall dependency declaration for this project.
+    // Right now, this is built from https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/test_resources/vllm.dockerfile
+    "sdk_container_image": "us.gcr.io/apache-beam-testing/python-postcommit-it/vllm:latest",
+    "sdk_location": files(configurations.distTarBall.files).singleFile,
+    "project": "apache-beam-testing",
+    "region": "us-central1",
+    "model": "facebook/opt-125m",
+    "output": "gs://apache-beam-ml/outputs/vllm_predictions.txt",
+    "disk_size_gb": 75
+  ]
+  def cmdArgs = mapToArgString(argMap)
+  // Exec one version with and one version without the chat option
+  exec {
+    executable 'sh'
+    args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'"
+  }
+  exec {
+    executable 'sh'
+    args '-c', ". ${envdir}/bin/activate && pip install openai && python -m apache_beam.examples.inference.vllm_text_completion $cmdArgs --chat true --chat_template 'gs://apache-beam-ml/additional_files/sample_chat_template.jinja' --experiment='worker_accelerator=type:nvidia-tesla-t4;count:1;install-nvidia-driver:5xx'"
+  }
+ }
+}
+
 // Vertex AI RunInference IT tests
 task vertexAIInferenceTest {
   dependsOn 'initializeForDataflowJob'
@@ -521,6 +554,12 @@ project.tasks.register("inferencePostCommitIT") {
   ]
 }
 
+project.tasks.register("inferencePostCommitITPy312") {
+  dependsOn = [
+  'vllmTests',
+  ]
+}
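+
+// Hypothetical local invocation of the new aggregate task, using the path wired into
+// python312PostCommit in build.gradle.kts:
+//   ./gradlew :sdks:python:test-suites:dataflow:py312:inferencePostCommitITPy312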
+
 
// Create cross-language tasks for running tests against Java expansion service(s)
 def gcpProject = project.findProperty('gcpProject') ?: 'apache-beam-testing'
