This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new 972d5d5e0c6 Update multi model notebook to remove workarounds (#27305)
972d5d5e0c6 is described below

commit 972d5d5e0c6d6ab9e35680cff71d6c5bb365afac
Author: Danny McCormick <dannymccorm...@google.com>
AuthorDate: Thu Jun 29 16:43:00 2023 -0400

    Update multi model notebook to remove workarounds (#27305)
    
    * Update multi model notebook to remove workarounds
    
    * Clean up + add conclusion
---
 .../beam-ml/run_inference_multi_model.ipynb        | 294 +++++++++------------
 1 file changed, 118 insertions(+), 176 deletions(-)

diff --git a/examples/notebooks/beam-ml/run_inference_multi_model.ipynb 
b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
index 9a99ad2cf47..7cd144223ca 100644
--- a/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
+++ b/examples/notebooks/beam-ml/run_inference_multi_model.ipynb
@@ -47,8 +47,7 @@
     {
       "cell_type": "markdown",
       "source": [
-        "# Ensemble model using an image captioning and ranking example",
-        "\n",
+        "# Ensemble model using an image captioning and ranking example\n",
         "<table align=\"left\">\n",
         "  <td>\n",
         "    <a target=\"_blank\" 
href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/run_inference_multi_model.ipynb\";><img
 
src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\";
 />Run in Google Colab</a>\n",
@@ -65,12 +64,12 @@
     {
       "cell_type": "markdown",
       "source": [
-        "A single machine learning model might not be the right solution for 
your task. Often, machine learning model tasks involve aggregating mutliple 
models together to produce one optimal predictive model and to boost 
performance. \n",
-        " \n",
+        "When performing complex tasks like image captioning, using a single 
ML model may not be the best solution.\n",
+        "\n",
         "\n",
         "This notebook shows how to implement a cascade model in Apache Beam 
using the [RunInference 
API](https://beam.apache.org/documentation/sdks/python-machine-learning/). The 
RunInference API enables you to run your Beam transforms as part of your 
pipeline for optimal machine learning inference.\n",
         "\n",
-        "For more information about the RunInference API, review the 
[RunInference 
notebook](https://colab.research.google.com/drive/111USL4VhUa0xt_mKJxl5nC1YLOC8_yF4?usp=sharing#scrollTo=746b67a7-3562-467f-bea3-d8cd18c14927).\n",
+        "For more information about the RunInference API, review the 
[RunInference 
notebook](https://colab.research.google.com/drive/111USL4VhUa0xt_mKJxl5nC1YLOC8_yF4?usp=sharing#scrollTo=746b67a7-3562-467f-bea3-d8cd18c14927)
 or the [Beam ML 
documentation](https://beam.apache.org/documentation/ml/overview/).\n",
         "\n",
         "**Note:** All images are licensed CC-BY, and creators are listed in 
the 
[LICENSE.txt](https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt)
 file."
       ],
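For readers new to the API, the following is a minimal, self-contained sketch (not taken from the notebook) of how RunInference is typically wired into a pipeline with a PyTorch model handler. The TinyLinearModel class and the state dict path are placeholders used only for illustration; the notebook itself uses BLIP and CLIP.

import apache_beam as beam
import torch
from apache_beam.ml.inference.base import RunInference
from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor

class TinyLinearModel(torch.nn.Module):
    # Hypothetical stand-in model, used only so the sketch is runnable.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

# Save a state dict so the handler has something to load (placeholder path).
torch.save(TinyLinearModel().state_dict(), 'tiny_linear_state_dict.pth')

model_handler = PytorchModelHandlerTensor(
    state_dict_path='tiny_linear_state_dict.pth',
    model_class=TinyLinearModel,
    model_params={})

with beam.Pipeline() as pipeline:
    _ = (pipeline
         | 'CreateExamples' >> beam.Create([torch.rand(4) for _ in range(3)])
         | 'RunInference' >> RunInference(model_handler)  # emits PredictionResult objects
         | 'PrintResults' >> beam.Map(print))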
@@ -94,7 +93,7 @@
         "\n",
         "This example shows how to generate captions on a a large set of 
images. Apache Beam is the ideal tool to handle this workflow. We use two 
models for this task:\n",
         "\n",
-        "* [BLIP](https://github.com/salesforce/BLIP): Generates a set of 
candidate captions for a given image. \n",
+        "* [BLIP](https://github.com/salesforce/BLIP): Generates a set of 
candidate captions for a given image.\n",
         "* [CLIP](https://github.com/openai/CLIP): Ranks the generated 
captions based on accuracy."
       ],
       "metadata": {
@@ -119,7 +118,7 @@
         "* Run inference with BLIP to generate a list of caption 
candidates.\n",
         "* Aggregate the generated captions with their source image.\n",
         "* Preprocess the aggregated image-caption pairs to rank them with 
CLIP.\n",
-        "* Run inference with CLIP to generate the caption ranking. \n",
+        "* Run inference with CLIP to generate the caption ranking.\n",
         "* Print the image names and the captions sorted according to their 
ranking.\n",
         "\n",
         "\n",
@@ -139,13 +138,13 @@
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/";,
-          "height": 440
+          "height": 460
         },
         "id": "3suC5woJLW_N",
-        "outputId": "d2f9f67b-361b-4ae9-f9db-ce2ff9abd509",
+        "outputId": "2b5e78bf-f212-4a77-9325-8808ef024c2e",
         "cellView": "form"
       },
-      "execution_count": null,
+      "execution_count": 1,
       "outputs": [
         {
           "output_type": "execute_result",
@@ -158,7 +157,7 @@
             ]
           },
           "metadata": {},
-          "execution_count": 3
+          "execution_count": 1
         }
       ]
     },
@@ -184,68 +183,34 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 32,
+      "execution_count": 1,
       "metadata": {
-        "id": "tTUZpG9_q-OW",
-        "colab": {
-          "base_uri": "https://localhost:8080/";
-        },
-        "outputId": "9ee6407a-8e4b-4520-fe5d-54a886b6e0b1"
+        "id": "tTUZpG9_q-OW"
       },
-      "outputs": [
-        {
-          "output_type": "stream",
-          "name": "stdout",
-          "text": [
-            "\u001b[K     |████████████████████████████████| 2.1 MB 7.0 MB/s 
\n",
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.4/3.4 
MB\u001b[0m \u001b[31m47.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m3.3/3.3 
MB\u001b[0m \u001b[31m90.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m182.4/182.4 
kB\u001b[0m \u001b[31m21.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m880.6/880.6 
kB\u001b[0m \u001b[31m69.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h  Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n",
-            "  Building wheel for sacremoses (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n",
-            "\u001b[33mWARNING: Running pip as the 'root' user can result in 
broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: 
https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n";,
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m377.0/377.0 
kB\u001b[0m \u001b[31m10.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can 
result in broken permissions and conflicting behaviour with the system package 
manager. It is recommended to use a virtual environment instead: 
https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n";,
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.1/53.1 
kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can 
result in broken permissions and conflicting behaviour with the system package 
manager. It is recommended to use a virtual environment instead: 
https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n";,
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m6.5/6.5 
MB\u001b[0m \u001b[31m60.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m12.8/12.8 
MB\u001b[0m \u001b[31m91.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h\u001b[33mWARNING: Running pip as the 'root' user can 
result in broken permissions and conflicting behaviour with the system package 
manager. It is recommended to use a virtual environment instead: 
https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n";,
-            "\u001b[2K     
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m 
\u001b[32m235.4/235.4 kB\u001b[0m \u001b[31m8.1 MB/s\u001b[0m eta 
\u001b[36m0:00:00\u001b[0m\n",
-            "\u001b[?25h  Installing build dependencies ... 
\u001b[?25l\u001b[?25hdone\n",
-            "  Getting requirements to build wheel ... 
\u001b[?25l\u001b[?25hdone\n",
-            "  Installing backend dependencies ... 
\u001b[?25l\u001b[?25hdone\n",
-            "  Preparing metadata (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n",
-            "  Building wheel for fairscale (pyproject.toml) ... 
\u001b[?25l\u001b[?25hdone\n",
-            "\u001b[33mWARNING: Running pip as the 'root' user can result in 
broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: 
https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n";,
-            "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can 
result in broken permissions and conflicting behaviour with the system package 
manager. It is recommended to use a virtual environment instead: 
https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n";,
-            "\u001b[0m"
-          ]
-        }
-      ],
+      "outputs": [],
       "source": [
         "!pip install --upgrade pip --quiet\n",
-        "!pip install transformers==4.15.0 --quiet\n",
+        "!pip install transformers==4.30.2 --quiet\n",
         "!pip install timm==0.4.12 --quiet\n",
         "!pip install ftfy==6.1.1 --quiet\n",
         "!pip install spacy==3.4.1 --quiet\n",
         "!pip install fairscale==0.4.4 --quiet\n",
-        "!pip install apache_beam[gcp]>=2.40.0  \n",
+        "!pip install apache_beam[gcp]>=2.48.0\n",
         "\n",
         "# To use the newly installed versions, restart the runtime.\n",
-        "exit() "
+        "exit()"
       ]
     },
     {
       "cell_type": "code",
       "source": [
         "import requests\n",
-        "import os \n",
+        "import os\n",
         "import urllib\n",
-        "import json  \n",
+        "import json\n",
         "import io\n",
         "from io import BytesIO\n",
+        "from typing import Sequence\n",
         "from typing import Iterator\n",
         "from typing import Iterable\n",
         "from typing import Tuple\n",
@@ -303,7 +268,7 @@
           "base_uri": "https://localhost:8080/";
         },
         "id": "Ud4sUXV2x8LO",
-        "outputId": "9e12ea04-a347-426f-8145-280a5676e78b"
+        "outputId": "cc814ff8-d424-4880-e006-56803e0508aa"
       },
       "execution_count": 2,
       "outputs": [
@@ -311,7 +276,6 @@
           "output_type": "stream",
           "name": "stdout",
           "text": [
-            "Error: Failed to call git rev-parse --git-dir --show-toplevel: 
\"fatal: not a git repository (or any of the parent directories): .git\\n\"\n",
             "Git LFS initialized.\n",
             "Cloning into 'clip-vit-base-patch32'...\n",
             "remote: Enumerating objects: 51, done.\u001b[K\n",
@@ -362,7 +326,7 @@
           "base_uri": "https://localhost:8080/";
         },
         "id": "g4-6WwqUtxea",
-        "outputId": "3b04b933-aab0-4f5b-c967-ed784125bc6a"
+        "outputId": "29112ca0-f111-48b7-d8cc-a4e04fb7a02b"
       },
       "execution_count": 4,
       "outputs": [
@@ -388,8 +352,8 @@
         "from BLIP.models.blip import blip_decoder\n",
         "\n",
         "!gdown 
'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth'\n",
-        "# The blip model is saved as a checkoint, load it and save it as a 
state dict since RunInference required \n",
-        "# a state dict for model instantiation \n",
+        "# The blip model is saved as a checkpoint, load it and save it as a 
state dict since RunInference required\n",
+        "# a state dict for model instantiation\n",
         "blip_state_dict_path = '/content/BLIP/blip_state_dict.pth'\n",
         
"torch.save(torch.load('/content/BLIP/model*_base_caption.pth')['model'], 
blip_state_dict_path)"
       ],
@@ -398,7 +362,7 @@
           "base_uri": "https://localhost:8080/";
         },
         "id": "GCvOP_iZh41c",
-        "outputId": "224c22b1-eda6-463c-c926-1341ec9edef8"
+        "outputId": "a96f0ff5-cdf7-4394-be6e-d5bfca2f3a1f"
       },
       "execution_count": 5,
       "outputs": [
@@ -409,7 +373,7 @@
             "Downloading...\n",
             "From: 
https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model*_base_caption.pth\n",
             "To: /content/BLIP/model*_base_caption.pth\n",
-            "100% 896M/896M [00:04<00:00, 198MB/s] \n"
+            "100% 896M/896M [00:04<00:00, 198MB/s]\n"
           ]
         }
       ]
@@ -500,9 +464,9 @@
         "\n",
         "  \"\"\"\n",
         "  Process the raw image input to a format suitable for BLIP 
inference. The processed\n",
-        "  images are duplicated to the number of desired captions per image. 
\n",
+        "  images are duplicated to the number of desired captions per 
image.\n",
         "\n",
-        "  Preprocessing transformation taken from: \n",
+        "  Preprocessing transformation taken from:\n",
         "  
https://github.com/salesforce/BLIP/blob/d10be550b2974e17ea72e74edc7948c9e5eab884/predict.py\n",
         "  \"\"\"\n",
         "\n",
@@ -510,7 +474,7 @@
         "    self._captions_per_image = captions_per_image\n",
         "\n",
         "  def setup(self):\n",
-        "    \n",
+        "\n",
         "    # Initialize the image transformer.\n",
         "    self._transform = transforms.Compose([\n",
         "      transforms.Resize((384, 
384),interpolation=InterpolationMode.BICUBIC),\n",
@@ -519,7 +483,7 @@
         "    ])\n",
         "\n",
         "  def process(self, element):\n",
-        "    image_url, image = element \n",
+        "    image_url, image = element\n",
         "    # The following lines provide a workaround to turn off 
BatchElements.\n",
         "    preprocessed_img = self._transform(image).unsqueeze(0)\n",
         "    preprocessed_img = 
preprocessed_img.repeat(self._captions_per_image, 1, 1, 1)\n",
@@ -533,7 +497,7 @@
         "  Process the PredictionResult to get the generated image captions\n",
         "  \"\"\"\n",
         "  def process(self, element : Tuple[str, 
Iterable[PredictionResult]]):\n",
-        "    image_url, prediction = element \n",
+        "    image_url, prediction = element\n",
         "\n",
         "    return [(image_url, prediction.inference)]"
       ],
@@ -546,7 +510,7 @@
     {
       "cell_type": "markdown",
       "source": [
-        "### Define CLIP functions \n",
+        "### Define CLIP functions\n",
         "\n",
         "Define the preprocessing and postprocessing functions for CLIP."
       ],
@@ -560,9 +524,9 @@
         "class PreprocessCLIPInput(beam.DoFn):\n",
         "\n",
         "  \"\"\"\n",
-        "  Process the image-caption pair to a format suitable for CLIP 
inference. \n",
+        "  Process the image-caption pair to a format suitable for CLIP 
inference.\n",
         "\n",
-        "  After grouping the raw images with the generated captions, we need 
to \n",
+        "  After grouping the raw images with the generated captions, we need 
to\n",
         "  preprocess them before passing them to the ranking stage (CLIP 
model).\n",
         "  \"\"\"\n",
         "\n",
@@ -572,12 +536,12 @@
         "               merges_file_config_path: str):\n",
         "\n",
         "    self._feature_extractor_config_path = 
feature_extractor_config_path\n",
-        "    self._tokenizer_vocab_config_path = tokenizer_vocab_config_path 
\n",
+        "    self._tokenizer_vocab_config_path = 
tokenizer_vocab_config_path\n",
         "    self._merges_file_config_path = merges_file_config_path\n",
         "\n",
         "\n",
         "  def setup(self):\n",
-        "    \n",
+        "\n",
         "    # Initialize the CLIP feature extractor.\n",
         "    feature_extractor_config = 
CLIPConfig.from_pretrained(self._feature_extractor_config_path)\n",
         "    feature_extractor = 
CLIPFeatureExtractor(feature_extractor_config)\n",
@@ -585,14 +549,14 @@
         "    # Initialize the CLIP tokenizer.\n",
         "    tokenizer = CLIPTokenizer(self._tokenizer_vocab_config_path,\n",
         "                              self._merges_file_config_path)\n",
-        "    \n",
+        "\n",
         "    # Initialize the CLIP processor used to process the image-caption 
pair.\n",
         "    self._processor = 
CLIPProcessor(feature_extractor=feature_extractor,\n",
         "                                    tokenizer=tokenizer)\n",
         "\n",
         "  def process(self, element: Tuple[str, Dict[str, List[Any]]]):\n",
         "\n",
-        "    image_url, image_captions_pair = element \n",
+        "    image_url, image_captions_pair = element\n",
         "    # Unpack the image and captions after grouping them with 
'CoGroupByKey()'.\n",
         "    image = image_captions_pair['image'][0]\n",
         "    captions = image_captions_pair['captions'][0]\n",
@@ -600,7 +564,7 @@
         "                                              text = captions,\n",
         "                                              
return_tensors=\"pt\",\n",
         "                                              padding=True)\n",
-        "    \n",
+        "\n",
         "    image_url_caption_pair = (image_url, captions)\n",
         "    return [(image_url_caption_pair, preprocessed_clip_input)]\n",
         "\n",
@@ -612,7 +576,7 @@
         "  The logits are the output of the CLIP model. Here, we apply a 
softmax activation\n",
         "  function to the logits to get the probabilistic distribution of the 
relevance\n",
         "  of each caption to the target image. After that, we sort the 
captions in descending\n",
-        "  order with respect to the probabilities as a caption-probability 
pair. \n",
+        "  order with respect to the probabilities as a caption-probability 
pair.\n",
         "  \"\"\"\n",
         "\n",
         "  def process(self, element : Tuple[Tuple[str, List[str]], 
Iterable[PredictionResult]]):\n",
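As a side note, the ranking described in this docstring boils down to a softmax over CLIP's image-to-caption logits followed by a sort. Here is a tiny standalone illustration of that computation, with fabricated logits rather than real CLIP output:

import torch

captions = ['caption a', 'caption b', 'caption c']
logits = torch.tensor([2.0, 0.5, 1.0])  # placeholder logits_per_image values
probs = torch.nn.functional.softmax(logits, dim=-1)
# Sort caption-probability pairs in descending order of probability.
ranked = sorted(zip(captions, probs.tolist()), key=lambda pair: pair[1], reverse=True)
print(ranked)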
@@ -642,7 +606,9 @@
     {
       "cell_type": "markdown",
       "source": [
-        "Use a `KeyedModelHandler` for both models to attach a key to the 
general `ModelHandler`.\n",
+        "A `ModelHandler` is Beam's method for defining the configuration 
needed to load and invoke your model. Since both the BLIP and CLIP models use 
PyTorch and take keyed tensors as inputs, we will use `PytorchModelHandlerKeyedTensor` for both.\n",
+        "\n",
+        "We will use a `KeyedModelHandler` for both models to attach a key to 
the general `ModelHandler`.\n",
         "The key is used for the following purposes:\n",
         "* To keep a reference to the image that the inference is associated 
with.\n",
         "* To aggregate transforms of different inputs.\n",
@@ -654,36 +620,6 @@
         "id": "BTmSPnjj8M2m"
       }
     },
-    {
-      "cell_type": "code",
-      "source": [
-        "class 
PytorchNoBatchModelHandlerKeyedTensor(PytorchModelHandlerKeyedTensor):\n",
-        "      \"\"\"Wrapper to PytorchModelHandler to limit batch size to 
1.\n",
-        "    The caption strings generated from the BLIP tokenizer might have 
different\n",
-        "    lengths. Different length strings don't work with torch.stack() 
in the current RunInference\n",
-        "    implementation, because stack() requires tensors to be the same 
size.\n",
-        "    Restricting max_batch_size to 1 means there is only 1 example per 
`batch`\n",
-        "    in the run_inference() call.\n",
-        "    \"\"\"\n",
-        "      # The following lines provide a workaround to turn off 
BatchElements.\n",
-        "      def batch_elements_kwargs(self):\n",
-        "          return {'max_batch_size': 1}"
-      ],
-      "metadata": {
-        "id": "OaR02_wxTMpc"
-      },
-      "execution_count": 9,
-      "outputs": []
-    },
-    {
-      "cell_type": "markdown",
-      "source": [
-        "Note that we use a `KeyedModelHandler` for both models to attach a 
key to the general `ModelHandler`. The key is used for aggregation transforms 
of different inputs."
-      ],
-      "metadata": {
-        "id": "gNLRO0EwvcGP"
-      }
-    },
     {
       "cell_type": "markdown",
       "source": [
@@ -713,48 +649,36 @@
     {
       "cell_type": "code",
       "source": [
-        "class BLIPWrapper(torch.nn.Module):\n",
-        "  \"\"\"\n",
-        "   Wrapper around the BLIP model to overwrite the default \"forward\" 
method with the \"generate\" method, because BLIP uses the \n",
-        "  \"generate\" method to produce the image captions.\n",
-        "  \"\"\"\n",
-        "  \n",
-        "  def __init__(self, base_model: blip_decoder, num_beams: int, 
max_length: int,\n",
-        "                min_length: int):\n",
-        "    super().__init__()\n",
-        "    self._model = base_model()\n",
-        "    self._num_beams = num_beams\n",
-        "    self._max_length = max_length\n",
-        "    self._min_length = min_length\n",
-        "\n",
-        "  def forward(self, inputs: torch.Tensor):\n",
-        "    # Squeeze because RunInference adds an extra dimension, which is 
empty.\n",
-        "    # The following lines provide a workaround to turn off 
BatchElements.\n",
-        "    inputs = inputs.squeeze(0)\n",
-        "    captions = self._model.generate(inputs,\n",
-        "                                    sample=True,\n",
-        "                                    num_beams=self._num_beams,\n",
-        "                                    max_length=self._max_length,\n",
-        "                                    min_length=self._min_length)\n",
-        "    return [captions]\n",
-        "\n",
-        "  def load_state_dict(self, state_dict: dict):\n",
-        "    self._model.load_state_dict(state_dict)\n",
-        "\n",
-        "\n",
-        "BLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n",
+        "def blip_keyed_tensor_inference_fn(\n",
+        "    batch: Sequence[Dict[str, torch.Tensor]],\n",
+        "    model: torch.nn.Module,\n",
+        "    device: str,\n",
+        "    inference_args: Optional[Dict[str, Any]] = None,\n",
+        "    model_id: Optional[str] = None,\n",
+        ") -> Iterable[PredictionResult]:\n",
+        "  # By default, Beam batches inputs for bulk inference and calls 
model(batch)\n",
+        "  # Since we want to call model.generate on a single unbatched input 
(BLIP/CLIP\n",
+        "  # don't handle batched inputs), we define a custom inference 
function.\n",
+        "  captions = model.generate(batch[0]['inputs'],\n",
+        "                            sample=True,\n",
+        "                            num_beams=NUM_BEAMS,\n",
+        "                            max_length=MAX_CAPTION_LENGTH,\n",
+        "                            min_length=MIN_CAPTION_LENGTH)\n",
+        "  return [PredictionResult(batch[0], captions, model_id)]\n",
+        "\n",
+        "\n",
+        "BLIP_model_handler = PytorchModelHandlerKeyedTensor(\n",
         "    state_dict_path=blip_state_dict_path,\n",
-        "    model_class=BLIPWrapper,\n",
-        "    model_params={'base_model': blip_decoder, 'num_beams': 
NUM_BEAMS,\n",
-        "                  'max_length': MAX_CAPTION_LENGTH, 'min_length': 
MIN_CAPTION_LENGTH},\n",
-        "    device='GPU')\n",
+        "    model_class=blip_decoder,\n",
+        "    inference_fn=blip_keyed_tensor_inference_fn,\n",
+        "    max_batch_size=1)\n",
         "\n",
         "BLIP_keyed_model_handler = KeyedModelHandler(BLIP_model_handler)"
       ],
       "metadata": {
         "id": "RCKBJjujVw4q"
       },
-      "execution_count": 11,
+      "execution_count": 10,
       "outputs": []
     },
     {
@@ -771,29 +695,33 @@
     {
       "cell_type": "code",
       "source": [
-        "class CLIPWrapper(CLIPModel):\n",
-        "\n",
-        "  def forward(self, **kwargs: Dict[str, torch.Tensor]):\n",
-        "    # Squeeze because RunInference adds an extra dimension, which is 
empty.\n",
-        "    # The following lines provide a workaround to turn off 
BatchElements.\n",
-        "    kwargs = {key: tensor.squeeze(0) for key, tensor in 
kwargs.items()}\n",
-        "    output = super().forward(**kwargs)\n",
-        "    logits = output.logits_per_image\n",
-        "    return logits\n",
-        "\n",
-        "\n",
-        "CLIP_model_handler = PytorchNoBatchModelHandlerKeyedTensor(\n",
+        "def clip_keyed_tensor_inference_fn(\n",
+        "    batch: Sequence[Dict[str, torch.Tensor]],\n",
+        "    model: torch.nn.Module,\n",
+        "    device: str,\n",
+        "    inference_args: Optional[Dict[str, Any]] = None,\n",
+        "    model_id: Optional[str] = None,\n",
+        ") -> Iterable[PredictionResult]:\n",
+        "  # By default, Beam batches inputs for bulk inference and calls 
model(batch)\n",
+        "  # Since we want to call model on a single unbatched input 
(BLIP/CLIP don't\n",
+        "  # handle batched inputs), we define a custom inference function.\n",
+        "  output = model(**batch[0], **inference_args)\n",
+        "  return [PredictionResult(batch[0], output.logits_per_image[0], 
model_id)]\n",
+        "\n",
+        "\n",
+        "CLIP_model_handler = PytorchModelHandlerKeyedTensor(\n",
         "    state_dict_path=clip_state_dict_path,\n",
-        "    model_class=CLIPWrapper,\n",
+        "    model_class=CLIPModel,\n",
         "    model_params={'config': 
CLIPConfig.from_pretrained(clip_model_config_path)},\n",
-        "    device='GPU')\n",
+        "    inference_fn=clip_keyed_tensor_inference_fn,\n",
+        "    max_batch_size=1)\n",
         "\n",
         "CLIP_keyed_model_handler = KeyedModelHandler(CLIP_model_handler)\n"
       ],
       "metadata": {
         "id": "EJw_OnZ1ZfuH"
       },
-      "execution_count": 12,
+      "execution_count": 11,
       "outputs": []
     },
     {
@@ -817,7 +745,7 @@
       "metadata": {
         "id": "VJwE0bquoXOf"
       },
-      "execution_count": 13,
+      "execution_count": 12,
       "outputs": []
     },
     {
@@ -834,7 +762,7 @@
       "source": [
         "#@title\n",
         "license_txt_url = 
'https://storage.googleapis.com/apache-beam-samples/image_captioning/LICENSE.txt'\n",
-        "license_dict = 
json.loads(urllib.request.urlopen(license_txt_url).read().decode(\"utf-8\")) 
\n",
+        "license_dict = 
json.loads(urllib.request.urlopen(license_txt_url).read().decode(\"utf-8\"))\n",
         "\n",
         "for image_url in images_url:\n",
         "  response = requests.get(image_url)\n",
@@ -855,7 +783,7 @@
         "outputId": "6e771e4e-a76a-4855-b466-976cdf35b506",
         "cellView": "form"
       },
-      "execution_count": 16,
+      "execution_count": null,
       "outputs": [
         {
           "output_type": "display_data",
@@ -918,7 +846,7 @@
       "metadata": {
         "id": "Dcz_M9GW0Kan"
       },
-      "execution_count": 14,
+      "execution_count": 13,
       "outputs": []
     },
     {
@@ -947,13 +875,13 @@
         "with beam.Pipeline() as pipeline:\n",
         "\n",
         "  read_images = (\n",
-        "            pipeline \n",
+        "            pipeline\n",
         "            | \"ReadUrl\" >> beam.Create(images_url)\n",
         "            | \"ReadImages\" >> beam.ParDo(ReadImagesFromUrl()))\n",
         "\n",
         "  blip_caption_generation = (\n",
         "            read_images\n",
-        "            | \"PreprocessBlipInput\" >> 
beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE)) \n",
+        "            | \"PreprocessBlipInput\" >> 
beam.ParDo(PreprocessBLIPInput(NUM_CAPTIONS_PER_IMAGE))\n",
         "            | \"GenerateCaptions\" >> 
RunInference(BLIP_keyed_model_handler)\n",
         "            | \"PostprocessCaptions\" >> 
beam.ParDo(PostprocessBLIPOutput()))\n",
         "\n",
@@ -966,19 +894,21 @@
         "                    clip_tokenizer_vocab_config_path,\n",
         "                    clip_merges_config_path))\n",
         "            | \"GetRankingLogits\" >> 
RunInference(CLIP_keyed_model_handler)\n",
-        "            | \"RankClipOutput\" >> beam.ParDo(RankCLIPOutput()))\n",
+        "            | \"RankClipOutput\" >> beam.ParDo(RankCLIPOutput())\n",
+        "            )\n",
         "\n",
         "  clip_captions_ranking | \"FormatCaptions\" >> 
beam.ParDo(FormatCaptions(NUM_TOP_CAPTIONS_TO_DISPLAY))\n",
-        "  "
+        ""
       ],
       "metadata": {
         "colab": {
-          "base_uri": "https://localhost:8080/";
+          "base_uri": "https://localhost:8080/";,
+          "height": 428
         },
         "id": "002e-FNbmuB8",
-        "outputId": "49c646f1-8612-433f-b134-ea8af0ff5591"
+        "outputId": "1b540b1e-b146-45d6-f8d3-ccaf461a87b7"
       },
-      "execution_count": 18,
+      "execution_count": 14,
       "outputs": [
         {
           "output_type": "stream",
@@ -986,29 +916,41 @@
           "text": [
             "Image: Paris-sunset\n",
             "\tTop 3 captions ranked by CLIP:\n",
-            "\t\t1: the eiffel tower in paris is silhouetted at sunset. 
(Caption probability: 0.23)\n",
-            "\t\t2: the sun sets over the city of paris, with the eiffel tower 
in the distance. (Caption probability: 0.19)\n",
-            "\t\t3: the sun sets over the eiffel tower in paris. (Caption 
probability: 0.17)\n",
+            "\t\t1: the setting sun is reflected in an orange setting sky over 
paris. (Caption probability: 0.28)\n",
+            "\t\t2: the sun rising above the eiffel tower over paris. (Caption 
probability: 0.23)\n",
+            "\t\t3: the sun setting over the eiffel tower and rooftops. 
(Caption probability: 0.15)\n",
             "\n",
             "\n",
             "Image: Wedges\n",
             "\tTop 3 captions ranked by CLIP:\n",
-            "\t\t1: a basket of baked fries with a sauce in it. (Caption 
probability: 0.60)\n",
-            "\t\t2: cooked french fries with ketchup and dip sitting in 
napkin. (Caption probability: 0.16)\n",
-            "\t\t3: some french fries with dipping sauce on the side. (Caption 
probability: 0.08)\n",
+            "\t\t1: sweet potato fries with ketchup served in bowl. (Caption 
probability: 0.73)\n",
+            "\t\t2: this is a plate of sweet potato fries with ketchup. 
(Caption probability: 0.16)\n",
+            "\t\t3: sweet potato fries and a dipping sauce are on the tray. 
(Caption probability: 0.06)\n",
             "\n",
             "\n",
             "Image: Hamsters\n",
             "\tTop 3 captions ranked by CLIP:\n",
-            "\t\t1: a person petting two small hamsters while in their home. 
(Caption probability: 0.51)\n",
-            "\t\t2: a woman holding two small white baby animals. (Caption 
probability: 0.23)\n",
-            "\t\t3: a hand holding a small mouse that looks tiny. (Caption 
probability: 0.09)\n",
+            "\t\t1: person holding two small animals in their hands. (Caption 
probability: 0.62)\n",
+            "\t\t2: a person's hand holding a small hamster in front of them. 
(Caption probability: 0.20)\n",
+            "\t\t3: a person holding a small animal in their hands. (Caption 
probability: 0.09)\n",
             "\n",
             "\n"
           ]
         }
       ]
     },
+    {
+      "cell_type": "markdown",
+      "source": [
+       "# Conclusion\n",
+       "\n",
+        "After running the pipeline, you can see the captions generated by the 
BLIP model and ranked by the CLIP model with all of our pre/postprocessing 
logic applied.\n",
+       "As you can see, running multi-model inference is easy with the power 
of Beam.\n"
+      ],
+      "metadata": {
+        "id": "gPCMXWgOtM_0"
+      }
+    },
     {
       "cell_type": "markdown",
       "source": [
