gemini-code-assist[bot] commented on code in PR #37186: URL: https://github.com/apache/beam/pull/37186#discussion_r3438734810
########## sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py: ########## @@ -0,0 +1,519 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This pipeline performs image classification using an open-source +PyTorch EfficientNet-B0 model optimized for T4 GPUs. +It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel, +and runs inference with adaptive batch sizing for optimal GPU utilization. +The pipeline targets stable and reproducible performance measurements under +continuous load. +Resources like Pub/Sub topic/subscription cleanup is handled programmatically. 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Optional +from typing import Tuple + +import torch +import torch.nn.functional as F + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: + """Decode bytes->RGB PIL->resize/crop->tensor->normalize.""" + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.thumbnail((256, 256)) + w, h = img.size + left = (w - size) // 2 + top = (h - size) // 2 + img = img.crop( + (max(0, left), max(0, top), min(w, left + size), min(h, top + size))) + + # To tensor [0..1] + import numpy as np + mean = np.array(IMAGENET_MEAN, dtype=np.float32) + std = np.array(IMAGENET_STD, dtype=np.float32) + arr = np.asarray(img).astype("float32") / 255.0 # H,W,3 + # Normalize + arr = (arr - mean) / std + # HWC -> CHW + arr = np.transpose(arr, (2, 0, 1)).astype("float32") + return 
torch.from_numpy(arr).float() # float32, shape (3,224,224) Review Comment:  ### Critical Bug: Shape Mismatch during Batching `img.thumbnail((256, 256))` preserves the aspect ratio and only ever shrinks the image: it scales the longer side down to at most 256, so for a non-square image the shorter side ends up below 256, and an image that is already smaller than 256 is not enlarged at all. The subsequent center crop with `size = 224` clamps its box to the image bounds, so whenever a side is shorter than 224 — for example any image whose aspect ratio exceeds 256/224 ≈ 1.14 (such as 4:3 or 16:9), or any small image — the cropped image is smaller than 224 in that dimension (e.g., a 4:3 landscape photo yields `224 x 192`, a 2:1 one `224 x 128`). Because different images in the dataset have different sizes and aspect ratios, they yield tensors of different shapes (e.g., `[3, 224, 128]` and `[3, 128, 224]`). When `RunInference` batches these tensors with `torch.stack`, it crashes with a shape mismatch error. To fix this, resize the image so that its shorter side is exactly 256 (upscaling small images as needed) while preserving the aspect ratio, and then center-crop to exactly `size x size`. ```python def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: """Decode bytes->RGB PIL->resize/crop->tensor->normalize.""" with PILImage.open(io.BytesIO(image_bytes)) as img: img = img.convert("RGB") w, h = img.size if w < h: new_w = 256 new_h = int(h * (256 / w)) else: new_h = 256 new_w = int(w * (256 / h)) img = img.resize((new_w, new_h), PILImage.BILINEAR) left = (new_w - size) // 2 top = (new_h - size) // 2 img = img.crop((left, top, left + size, top + size)) # To tensor [0..1] import numpy as np mean = np.array(IMAGENET_MEAN, dtype=np.float32) std = np.array(IMAGENET_STD, dtype=np.float32) arr = np.asarray(img).astype("float32") / 255.0 # H,W,3 # Normalize arr = (arr - mean) / std # HWC -> CHW arr = np.transpose(arr, (2, 0, 1)).astype("float32") return torch.from_numpy(arr).float() # float32, shape (3,224,224) ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py: ########## @@ -0,0 +1,535 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements.
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs object detection using an open-source PyTorch +TorchVision detection model (e.g., Faster R-CNN ResNet50 FPN) on GPU. + +It reads image URIs from a GCS input file, decodes and preprocesses images, +runs batched GPU inference via RunInference, post-processes detection outputs, +and writes results to BigQuery. + +The pipeline targets stable and reproducible performance measurements for +GPU inference workloads (no right-fitting; fixed batch size). 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def decode_to_tens( + image_bytes: bytes, + resize_shorter_side: Optional[int] = None) -> torch.Tensor: + """Decode bytes -> RGB PIL -> optional resize -> float tensor [0..1], CHW. + + Note: TorchVision detection models apply their own normalization internally. + """ + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + + if resize_shorter_side and resize_shorter_side > 0: + w, h = img.size + # Resize so that shorter side == resize_shorter_side, keep aspect ratio. 
+ if w < h: + new_w = resize_shorter_side + new_h = int(h * (resize_shorter_side / float(w))) + else: + new_h = resize_shorter_side + new_w = int(w * (resize_shorter_side / float(h))) + img = img.resize((new_w, new_h)) + + import numpy as np + arr = np.asarray(img).astype("float32") / 255.0 # H,W,3 in [0..1] + arr = np.transpose(arr, (2, 0, 1)) # CHW + return torch.from_numpy(arr) Review Comment:  ### Critical Bug: Varying Image Shapes Cause a Stacking Crash `PytorchModelHandlerTensor` internally calls `torch.stack(batch)` to create a batched tensor. If the images in a batch have different sizes or aspect ratios, `torch.stack` raises `RuntimeError: stack expects each tensor to be equal size` and the pipeline crashes. The existing `resize_shorter_side` option does not prevent this: it fixes only the shorter side, so the longer side (and the orientation) still varies from image to image, and when the option is unset no resizing happens at all. To make the pipeline robust for general datasets, resize all images to a fixed square size (e.g., `640x640` or `800x800`) during preprocessing. Keep in mind that a fixed square resize distorts the aspect ratio and that the returned boxes are then in resized-image coordinates. Because TorchVision detection models accept a list of variable-size image tensors, an alternative that avoids both issues is to pass a custom `inference_fn` that moves each tensor to the device and calls the model on the list instead of stacking. ```python def decode_to_tens( image_bytes: bytes, size: Tuple[int, int] = (640, 640)) -> torch.Tensor: """Decode bytes -> RGB PIL -> resize to fixed size -> float tensor [0..1], CHW.""" with PILImage.open(io.BytesIO(image_bytes)) as img: img = img.convert("RGB") img = img.resize(size, PILImage.BILINEAR) import numpy as np arr = np.asarray(img).astype("float32") / 255.0 # H,W,3 in [0..1] arr = np.transpose(arr, (2, 0, 1)) # CHW return torch.from_numpy(arr) ``` ########## sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py: ########## @@ -0,0 +1,519 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This pipeline performs image classification using an open-source +PyTorch EfficientNet-B0 model optimized for T4 GPUs. +It reads image URIs from Pub/Sub, decodes and preprocesses them in parallel, +and runs inference with adaptive batch sizing for optimal GPU utilization. +The pipeline targets stable and reproducible performance measurements under +continuous load. +Resources like Pub/Sub topic/subscription cleanup is handled programmatically. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Optional +from typing import Tuple + +import torch +import torch.nn.functional as F + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.ml.inference.pytorch_inference import PytorchModelHandlerTensor +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import PIL.Image as PILImage + +# ============ Utility & Preprocessing ============ + +IMAGENET_MEAN = [0.485, 0.456, 0.406] +IMAGENET_STD = [0.229, 0.224, 0.225] + + +def now_millis() -> int: + return 
int(time.time() * 1000) + + +def load_image_from_uri(uri: str) -> bytes: + with FileSystems.open(uri) as f: + return f.read() + + +def decode_and_preprocess(image_bytes: bytes, size: int = 224) -> torch.Tensor: + """Decode bytes->RGB PIL->resize/crop->tensor->normalize.""" + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.thumbnail((256, 256)) + w, h = img.size + left = (w - size) // 2 + top = (h - size) // 2 + img = img.crop( + (max(0, left), max(0, top), min(w, left + size), min(h, top + size))) + + # To tensor [0..1] + import numpy as np + mean = np.array(IMAGENET_MEAN, dtype=np.float32) + std = np.array(IMAGENET_STD, dtype=np.float32) + arr = np.asarray(img).astype("float32") / 255.0 # H,W,3 + # Normalize + arr = (arr - mean) / std + # HWC -> CHW + arr = np.transpose(arr, (2, 0, 1)).astype("float32") + return torch.from_numpy(arr).float() # float32, shape (3,224,224) + + +class MakeKeyDoFn(beam.DoFn): + """Produce (image_id, payload) stable for dedup & BQ insertId.""" + def __init__(self, input_mode: str): + self.input_mode = input_mode + + def process(self, element: str | bytes): + # Input can be raw bytes from Pub/Sub or a GCS URI string, depends on mode + if self.input_mode == "bytes": + # element is bytes message, assume it includes + # {"image_id": "...", "bytes": base64?} or just raw bytes. 
+ import hashlib + b = element if isinstance(element, + (bytes, + bytearray)) else element.encode('utf-8') + image_id = hashlib.sha1(b).hexdigest() + yield image_id, b + else: + # gcs_uris: element is uri string; image_id = sha1(uri) + import hashlib + uri = element.decode("utf-8") if isinstance( + element, (bytes, bytearray)) else str(element) + image_id = hashlib.sha1(uri.encode("utf-8")).hexdigest() + yield image_id, uri + + +class DecodePreprocessDoFn(beam.DoFn): + """Turn (image_id, bytes|uri) -> (image_id, torch.Tensor)""" + def __init__( + self, input_mode: str, image_size: int = 224, decode_threads: int = 4): + self.input_mode = input_mode + self.image_size = image_size + self.decode_threads = decode_threads + + def process(self, kv: Tuple[str, object]): + image_id, payload = kv + start = now_millis() + + try: + if self.input_mode == "bytes": + b = payload if isinstance(payload, + (bytes, bytearray)) else bytes(payload) + else: + uri = payload if isinstance(payload, str) else payload.decode("utf-8") + b = load_image_from_uri(uri) + + tensor = decode_and_preprocess(b, self.image_size) + preprocess_ms = now_millis() - start + yield image_id, {"tensor": tensor, "preprocess_ms": preprocess_ms} + except Exception as e: + logging.warning("Decode failed for %s: %s", image_id, e) + return + + +class PostProcessDoFn(beam.DoFn): + """PredictionResult -> dict row for BQ.""" + def __init__(self, top_k: int, model_name: str): + self.top_k = top_k + self.model_name = model_name + + def process(self, kv: Tuple[str, PredictionResult]): + image_id, pred = kv + + # pred can be PredictionResult OR raw inference object. + inference_obj = pred.inference if hasattr(pred, "inference") else pred + + # inference_obj can be dict {'logits': tensor} OR tensor directly. + if isinstance(inference_obj, dict): + logits = inference_obj.get("logits", None) + if logits is None: + raise ValueError( + f"Unable to find 'logits' in model output. 
" + f"Available keys: {list(inference_obj.keys())}" + ) + else: + logits = inference_obj + + if not isinstance(logits, torch.Tensor): + logging.warning( + "Unexpected logits type for %s: %s", image_id, type(logits)) + return + + # Ensure shape [1, C] + if logits.ndim == 1: + logits = logits.unsqueeze(0) + + probs = F.softmax(logits, dim=-1) # [B, C] + values, indices = torch.topk( + probs, k=min(self.top_k, probs.shape[-1]), dim=-1 + ) + + topk = [{ + "class_id": int(idx.item()), "score": float(val.item()) + } for idx, val in zip(indices[0], values[0])] + + yield { + "image_id": image_id, + "model_name": self.model_name, + "topk": json.dumps(topk), + "infer_ms": now_millis(), + } + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + # I/O & runtime + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--input_mode', default='gcs_uris', choices=['gcs_uris', 'bytes']) + parser.add_argument( + '--input', + required=True, + help='GCS path to file with URIs (for load) OR unused for bytes') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. 
This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Model & inference + parser.add_argument( + '--pretrained_model_name', + default='efficientnet_b0', + help='OSS model name (e.g., efficientnet_b0|mobilenetv3_large_100)') + parser.add_argument( + '--model_state_dict_path', + default=None, + help='Optional state_dict to load') + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + parser.add_argument('--image_size', type=int, default=224) + parser.add_argument('--top_k', type=int, default=5) + parser.add_argument( + '--inference_batch_size', + default='auto', + help='int or "auto"; auto tries 64→32→16') + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = 
publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Model factory (timm) ============ + + +def create_timm_m(model_name: str, num_classes: int = 1000): + import timm + model = timm.create_model( + model_name, pretrained=True, num_classes=num_classes) + model.eval() + return model + + +def pick_batch_size(arg: str) -> Optional[int]: + if isinstance(arg, str) and arg.lower() == 'auto': + return None + try: + return int(arg) + except Exception: + return None + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') + override_or_add(pipeline_args, '--num_workers', '5') + override_or_add(pipeline_args, '--max_num_workers', '10') + override_or_add( + pipeline_args, '--job_name', f"images-load-pubsub-{int(time.time())}") + override_or_add(pipeline_args, '--project', known_args.project) + pipeline_args = [ + arg for arg in pipeline_args if not arg.startswith("--experiments") + ] + + pipeline_options = PipelineOptions(pipeline_args) + pipeline = beam.Pipeline(options=pipeline_options) + + _ = ( + pipeline + | 'ReadGCSFile' >> beam.io.ReadFromText(known_args.input) + | 
'FilterEmpty' >> beam.Filter(lambda line: line.strip()) + | 'ToBytes' >> beam.Map(lambda line: line.encode('utf-8')) + | 'ToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic)) + return pipeline.run() + + +# ============ Main pipeline ============ + + +def run( + argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult: + known_args, pipeline_args = parse_known_args(argv) + + if known_args.mode == 'streaming': + ensure_pubsub_resources( + project=known_args.project, + topic_path=known_args.pubsub_topic, + subscription_path=known_args.pubsub_subscription) + + # Start feeder thread that reads URIs from GCS and fills Pub/Sub. + # Delay is used to allow the main streaming pipeline workers to start + # and autoscale before the feeder pipeline begins publishing messages. + threading.Thread( + target=lambda: ( + time.sleep(known_args.feeder_start_delay_sec), run_load_pipeline( + known_args, pipeline_args)), + daemon=True).start() + + # StandardOptions + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(StandardOptions).streaming = ( + known_args.mode == 'streaming') + + # Build model handler with right-fitting batch size + desired_batch = pick_batch_size(known_args.inference_batch_size) + tried = [64, 32, 16] if desired_batch is None else [desired_batch] + + # Device + device = 'GPU' if known_args.device.upper() == 'GPU' else 'CPU' + + bs_ok = None + last_err = None + for bs in tried: + try: + model_handler = PytorchModelHandlerTensor( + model_class=lambda: create_timm_m(known_args.pretrained_model_name), + model_params={}, + state_dict_path=known_args.model_state_dict_path, + device=device, + inference_batch_size=bs + if bs is not None else 64, # start guess for warmup + ) + # quick warmup to validate memory (single dummy tensor) + dummy = torch.zeros((3, known_args.image_size, known_args.image_size), + dtype=torch.float32) + _ = 
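model_handler.load_model() # ensures weights on device + with torch.no_grad(): + mdl = model_handler._model + mdl(torch.unsqueeze(dummy, 0)) + bs_ok = bs if bs is not None else 64 + break + except RuntimeError as e: + last_err = e + logging.warning("Batch size %s failed during warmup: %s", bs, e) + continue + + if bs_ok is None: + logging.warning( + "Falling back to batch_size=8 due to previous errors: %s", last_err) + bs_ok = 8 + model_handler = PytorchModelHandlerTensor( + model_class=lambda: create_timm_m(known_args.pretrained_model_name), + model_params={}, + state_dict_path=known_args.model_state_dict_path, + device=device, + inference_batch_size=bs_ok, + ) + Review Comment:  ### Architectural Flaw: Warmup / Right-fitting Executed on Submission Client Currently, the warmup loop that determines the batch size runs directly inside the `run()` function on the submission client (driver), before the pipeline is sent to the runner. This has two major issues: 1. If the client machine has no GPU (typical for CI/CD runners and local submission environments), the warmup cannot measure GPU memory at all. Beam's PyTorch handler logs a warning and falls back to CPU when CUDA is unavailable, yet the pretrained weights are still downloaded and loaded on the client during pipeline submission. 2. Even with a GPU present, the warmup selects a batch size based on the client's hardware instead of the actual worker's hardware (e.g., a Tesla T4 on Google Cloud). The warmup is also ineffective as written. First, it runs a single dummy sample (`torch.unsqueeze(dummy, 0)`), so it never exercises batch size `bs` and cannot detect an out-of-memory condition at that size. Second, it reads `model_handler._model`, which `load_model()` does not set (it returns the model instead), and the resulting `AttributeError` is not caught by `except RuntimeError`; use the model returned by `load_model()`. To fix this, subclass `PytorchModelHandlerTensor` and perform the warmup inside the `load_model` method, which executes on the workers, using a dummy batch of shape `(bs, 3, image_size, image_size)`. Note that `RunInference` builds its `BatchElements` step from `batch_elements_kwargs()` when the pipeline is constructed. A batch size chosen later inside `load_model` on a worker therefore does not change how elements are batched; it can only be used to split each incoming batch into smaller chunks inside `run_inference`. 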
```python class RightFitPytorchModelHandler(PytorchModelHandlerTensor): def __init__(self, tried_batch_sizes, image_size, *args, **kwargs): super().__init__(*args, **kwargs) self.tried_batch_sizes = tried_batch_sizes self.image_size = image_size def load_model(self): last_err = None for bs in self.tried_batch_sizes: try: self._inference_batch_size = bs model = super().load_model() # Warmup to validate memory on the actual worker dummy = torch.zeros((3, self.image_size, self.image_size), dtype=torch.float32).to(self.device) with torch.no_grad(): model(torch.unsqueeze(dummy, 0)) return model except RuntimeError as e: last_err = e logging.warning("Batch size %s failed during warmup: %s", bs, e) continue # Fallback self._inference_batch_size = 8 return super().load_model() ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,637 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. 
+ +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. +""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class MakeKeyDoFn(beam.DoFn): + """Produce (uri, uri) so the URI is used as the stable key.""" + def process(self, element: str): + uri = element + yield uri, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (uri, uri) -> (uri, dict(image_bytes)).""" + def process(self, kv: Tuple[str, str]): + uri, _ = kv + try: + with FileSystems.open(uri) as f: + image_bytes = f.read() + yield uri, {"image_bytes": image_bytes} + except OSError as e: + logging.warning("Failed to read image %s: %s", uri, e) + return + + +class 
DecodeImageDoFn(beam.DoFn): + """Turn (uri, dict(image_bytes)) -> (uri, dict(image)).""" + def process(self, kv: Tuple[str, Dict[str, Any]]): + uri, value = kv + image_bytes = value["image_bytes"] + + try: + image = decode_pil(image_bytes) + except Exception as e: + logging.warning("Failed to decode image %s: %s", uri, e) + image = PILImage.new("RGB", (224, 224), color=(0, 0, 0)) + + yield uri, {"image": image} + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + uri, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": uri, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + 
self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + processor = BlipProcessor.from_pretrained(self.model_name) + model = BlipForConditionalGeneration.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start = now_millis() + + images = [x["image"] for x in batch] + + # Processor makes pixel_values + inputs = processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. + with torch.no_grad(): + generated_ids = model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image": images[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + score_normalize: bool): + self.model_name = model_name + self.device = device + 
self.batch_size = batch_size + self.score_normalize = score_normalize + + def load_model(self): + from transformers import CLIPModel, CLIPProcessor + processor = CLIPProcessor.from_pretrained(self.model_name) + model = CLIPModel.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start_batch = now_millis() + + # Flat lists for a single batched CLIP forward pass + images: List[PILImage.Image] = [] + texts: List[str] = [] + offsets: List[Tuple[int, int]] = [] + # per element -> [start, end) in flat arrays + candidates_list: List[List[str]] = [] + blip_ms_list: List[Optional[int]] = [] + + for x in batch: + img = x["image"] + candidates = [str(c) for c in (x.get("candidates", []) or [])] + candidates_list.append(candidates) + blip_ms_list.append(x.get("blip_ms", None)) + + start_i = len(texts) + for c in candidates: + images.append(img) + texts.append(c) + end_i = len(texts) + offsets.append((start_i, end_i)) + + results: List[Dict[str, Any]] = [] + + # Fast path: no candidates at all + if not texts: + for blip_ms in blip_ms_list: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + return results + + with torch.no_grad(): + inputs = processor( + text=texts, + images=images, + return_tensors="pt", + padding=True, + truncation=True, + ) + inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in inputs.items() + } + + # avoid NxN logits inside CLIPModel.forward() + img = model.get_image_features( + pixel_values=inputs["pixel_values"]) # [N, D] + txt = model.get_text_features( + input_ids=inputs["input_ids"], + 
attention_mask=inputs.get("attention_mask"), + ) # [N, D] + + img = img / img.norm(dim=-1, keepdim=True) + txt = txt / txt.norm(dim=-1, keepdim=True) + + logit_scale = model.logit_scale.exp() # scalar tensor + pair_scores = (img * txt).sum(dim=-1) * logit_scale # [N] + pair_scores_cpu = pair_scores.detach().cpu().tolist() Review Comment:  ### Performance Bottleneck: Redundant Image Encoding in CLIP In the current implementation, the same image is duplicated `num_captions` (default: 5) times in the `images` list and passed to `processor` and `model.get_image_features`. Since the image encoder (ViT) is computationally heavy, encoding the exact same image multiple times per batch element is highly inefficient. We can optimize this by encoding only the **unique** active images once, and then repeating/aligning their features to match the flat candidate text features before computing the cosine similarity. ```python with torch.no_grad(): # Extract unique images to avoid redundant encoding unique_images = [] for x in batch: if x.get("candidates"): unique_images.append(x["image"]) image_inputs = processor(images=unique_images, return_tensors="pt") image_inputs = {k: v.to(self.device) for k, v in image_inputs.items()} text_inputs = processor( text=texts, return_tensors="pt", padding=True, truncation=True, ) text_inputs = {k: v.to(self.device) for k, v in text_inputs.items()} img_features = model.get_image_features(**image_inputs) # [B_active, D] txt_features = model.get_text_features(**text_inputs) # [total_pairs, D] img_features = img_features / img_features.norm(dim=-1, keepdim=True) txt_features = txt_features / txt_features.norm(dim=-1, keepdim=True) logit_scale = model.logit_scale.exp() # scalar tensor # Align image features with text features repeated_img_features = [] active_idx = 0 for start_i, end_i in offsets: if start_i != end_i: num_candidates = end_i - start_i repeated_img_features.append(img_features[active_idx].repeat(num_candidates, 1)) active_idx += 1 if 
repeated_img_features: repeated_img_features = torch.cat(repeated_img_features, dim=0) pair_scores = (repeated_img_features * txt_features).sum(dim=-1) * logit_scale pair_scores_cpu = pair_scores.detach().cpu().tolist() else: pair_scores_cpu = [] ``` ########## sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py: ########## @@ -0,0 +1,637 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""This pipeline performs image captioning using a multi-model approach: +BLIP generates candidate captions, CLIP ranks them by image-text similarity. + +The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP +caption generation in batches on GPU, then runs CLIP ranking in batches on GPU. +Results are written to BigQuery. 
+""" + +import argparse +import io +import json +import logging +import threading +import time +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Tuple + +import apache_beam as beam +from apache_beam.io.filesystems import FileSystems +from apache_beam.ml.inference.base import KeyedModelHandler +from apache_beam.ml.inference.base import ModelHandler +from apache_beam.ml.inference.base import PredictionResult +from apache_beam.ml.inference.base import RunInference +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import SetupOptions +from apache_beam.options.pipeline_options import StandardOptions +from apache_beam.runners.runner import PipelineResult +from apache_beam.transforms import window + +from google.api_core.exceptions import NotFound +from google.cloud import pubsub_v1 +import torch +import PIL.Image as PILImage + +# ============ Utility ============ + + +def now_millis() -> int: + return int(time.time() * 1000) + + +def decode_pil(image_bytes: bytes) -> PILImage.Image: + with PILImage.open(io.BytesIO(image_bytes)) as img: + img = img.convert("RGB") + img.load() + return img + + +# ============ DoFns ============ + + +class MakeKeyDoFn(beam.DoFn): + """Produce (uri, uri) so the URI is used as the stable key.""" + def process(self, element: str): + uri = element + yield uri, uri + + +class ReadImageBytesDoFn(beam.DoFn): + """Turn (uri, uri) -> (uri, dict(image_bytes)).""" + def process(self, kv: Tuple[str, str]): + uri, _ = kv + try: + with FileSystems.open(uri) as f: + image_bytes = f.read() + yield uri, {"image_bytes": image_bytes} + except OSError as e: + logging.warning("Failed to read image %s: %s", uri, e) + return + + +class DecodeImageDoFn(beam.DoFn): + """Turn (uri, dict(image_bytes)) -> (uri, dict(image)).""" + def process(self, kv: Tuple[str, Dict[str, Any]]): + uri, value = kv + image_bytes = value["image_bytes"] + + 
try: + image = decode_pil(image_bytes) + except Exception as e: + logging.warning("Failed to decode image %s: %s", uri, e) + image = PILImage.new("RGB", (224, 224), color=(0, 0, 0)) + + yield uri, {"image": image} + + +class PostProcessDoFn(beam.DoFn): + """Final PredictionResult -> row for BigQuery.""" + def __init__(self, blip_name: str, clip_name: str): + self.blip_name = blip_name + self.clip_name = clip_name + + def process(self, kv: Tuple[str, PredictionResult]): + uri, pred = kv + if hasattr(pred, "inference"): + inf = pred.inference or {} + else: + inf = pred + # Expected inference fields from CLIP handler: + # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms + best_caption = inf.get("best_caption", "") + best_score = inf.get("best_score", None) + candidates = inf.get("candidates", []) + scores = inf.get("scores", []) + blip_ms = inf.get("blip_ms", None) + clip_ms = inf.get("clip_ms", None) + total_ms = inf.get("total_ms", None) + + yield { + "image_id": uri, + "blip_model": self.blip_name, + "clip_model": self.clip_name, + "best_caption": best_caption, + "best_score": float(best_score) if best_score is not None else None, + "candidates": json.dumps(candidates), + "scores": json.dumps(scores), + "blip_ms": int(blip_ms) if blip_ms is not None else None, + "clip_ms": int(clip_ms) if clip_ms is not None else None, + "total_ms": int(total_ms) if total_ms is not None else None, + "infer_ms": now_millis(), + } + + +# ============ Model Handlers ============ + + +class BlipCaptionModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + num_captions: int, + max_new_tokens: int, + num_beams: int): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.num_captions = num_captions + self.max_new_tokens = max_new_tokens + self.num_beams = num_beams + + def load_model(self): + from transformers import BlipForConditionalGeneration, BlipProcessor + processor = 
BlipProcessor.from_pretrained(self.model_name) + model = BlipForConditionalGeneration.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start = now_millis() + + images = [x["image"] for x in batch] + + # Processor makes pixel_values + inputs = processor(images=images, return_tensors="pt") + pixel_values = inputs["pixel_values"].to(self.device) + + # Generate captions + # We use num_return_sequences to generate multiple candidates per image. + # Note: this will produce (B * num_captions) sequences. + with torch.no_grad(): + generated_ids = model.generate( + pixel_values=pixel_values, + max_new_tokens=self.max_new_tokens, + num_beams=max(self.num_beams, self.num_captions), + num_return_sequences=self.num_captions, + do_sample=False, + ) + + captions_all = processor.batch_decode( + generated_ids, skip_special_tokens=True) + + # Group candidates per image + candidates_per_image = [] + idx = 0 + for _ in range(len(batch)): + candidates_per_image.append(captions_all[idx:idx + self.num_captions]) + idx += self.num_captions + + blip_ms = now_millis() - start + + results = [] + for i in range(len(batch)): + results.append({ + "image": images[i], + "candidates": candidates_per_image[i], + "blip_ms": blip_ms, + }) + return results + + def get_metrics_namespace(self) -> str: + return "blip_captioning" + + +class ClipRankModelHandler(ModelHandler): + def __init__( + self, + model_name: str, + device: str, + batch_size: int, + score_normalize: bool): + self.model_name = model_name + self.device = device + self.batch_size = batch_size + self.score_normalize = score_normalize + + def load_model(self): + from transformers import CLIPModel, CLIPProcessor + processor = CLIPProcessor.from_pretrained(self.model_name) + model 
= CLIPModel.from_pretrained(self.model_name) + model.to(self.device) + model.eval() + return (model, processor) + + def batch_elements_kwargs(self): + return {"max_batch_size": self.batch_size} + + def run_inference( + self, batch: List[Dict[str, Any]], model_bundle, inference_args=None): + + model, processor = model_bundle + start_batch = now_millis() + + # Flat lists for a single batched CLIP forward pass + images: List[PILImage.Image] = [] + texts: List[str] = [] + offsets: List[Tuple[int, int]] = [] + # per element -> [start, end) in flat arrays + candidates_list: List[List[str]] = [] + blip_ms_list: List[Optional[int]] = [] + + for x in batch: + img = x["image"] + candidates = [str(c) for c in (x.get("candidates", []) or [])] + candidates_list.append(candidates) + blip_ms_list.append(x.get("blip_ms", None)) + + start_i = len(texts) + for c in candidates: + images.append(img) + texts.append(c) + end_i = len(texts) + offsets.append((start_i, end_i)) + + results: List[Dict[str, Any]] = [] + + # Fast path: no candidates at all + if not texts: + for blip_ms in blip_ms_list: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + return results + + with torch.no_grad(): + inputs = processor( + text=texts, + images=images, + return_tensors="pt", + padding=True, + truncation=True, + ) + inputs = { + k: (v.to(self.device) if torch.is_tensor(v) else v) + for k, v in inputs.items() + } + + # avoid NxN logits inside CLIPModel.forward() + img = model.get_image_features( + pixel_values=inputs["pixel_values"]) # [N, D] + txt = model.get_text_features( + input_ids=inputs["input_ids"], + attention_mask=inputs.get("attention_mask"), + ) # [N, D] + + img = img / img.norm(dim=-1, keepdim=True) + txt = txt / txt.norm(dim=-1, keepdim=True) + + logit_scale = model.logit_scale.exp() # scalar tensor + pair_scores = 
(img * txt).sum(dim=-1) * logit_scale # [N] + pair_scores_cpu = pair_scores.detach().cpu().tolist() + + batch_ms = now_millis() - start_batch + total_pairs = len(texts) + + items = zip(offsets, candidates_list, blip_ms_list) + for (start_i, end_i), candidates, blip_ms in items: + if start_i == end_i: + total_ms = int(blip_ms) if blip_ms is not None else None + results.append({ + "best_caption": "", + "best_score": None, + "candidates": [], + "scores": [], + "blip_ms": blip_ms, + "clip_ms": 0, + "total_ms": total_ms, + }) + continue + + scores = [float(pair_scores_cpu[j]) for j in range(start_i, end_i)] + + if self.score_normalize: + scores_t = torch.tensor(scores, dtype=torch.float32) + scores = torch.softmax(scores_t, dim=0).tolist() + + best_idx = max(range(len(scores)), key=lambda i, s=scores: s[i]) + + pairs = end_i - start_i + clip_ms_elem = int(batch_ms * (pairs / max(1, total_pairs))) + if pairs > 0: + clip_ms_elem = max(1, clip_ms_elem) + + total_ms = int(blip_ms) + clip_ms_elem if blip_ms is not None else None + results.append({ + "best_caption": candidates[best_idx], + "best_score": float(scores[best_idx]), + "candidates": candidates, + "scores": scores, + "blip_ms": blip_ms, + "clip_ms": clip_ms_elem, + "total_ms": total_ms, + }) + + return results + + def get_metrics_namespace(self) -> str: + return "clip_ranking" + + +# ============ Args & Helpers ============ + + +def parse_known_args(argv): + parser = argparse.ArgumentParser() + + # I/O & runtime + parser.add_argument( + '--mode', default='streaming', choices=['streaming', 'batch']) + parser.add_argument( + '--project', default='apache-beam-testing', help='GCP project ID') + parser.add_argument( + '--input', required=True, help='GCS path to file with image URIs') + parser.add_argument( + '--pubsub_topic', + default='projects/apache-beam-testing/topics/images_topic') + parser.add_argument( + '--pubsub_subscription', + default='projects/apache-beam-testing/subscriptions/images_subscription') + 
parser.add_argument( + '--output_table', + required=True, + help='BigQuery output table: dataset.table') + parser.add_argument( + '--publish_to_big_query', default='true', choices=['true', 'false']) + parser.add_argument( + '--feeder_start_delay_sec', + type=int, + default=900, + help=( + 'Delay before starting the feeder pipeline that reads URIs from GCS ' + 'and publishes them to Pub/Sub. This delay allows the main streaming ' + 'pipeline workers to start and scale before data ingestion begins.'), + ) + + # Device + parser.add_argument('--device', default='GPU', choices=['CPU', 'GPU']) + + # BLIP + parser.add_argument( + '--blip_model_name', default='Salesforce/blip-image-captioning-base') + parser.add_argument('--blip_batch_size', type=int, default=4) + parser.add_argument('--num_captions', type=int, default=5) + parser.add_argument('--max_new_tokens', type=int, default=30) + parser.add_argument('--num_beams', type=int, default=5) + + # CLIP + parser.add_argument( + '--clip_model_name', default='openai/clip-vit-base-patch32') + parser.add_argument('--clip_batch_size', type=int, default=8) + parser.add_argument( + '--clip_score_normalize', default='false', choices=['true', 'false']) + + # Windows + parser.add_argument('--window_sec', type=int, default=60) + parser.add_argument('--trigger_proc_time_sec', type=int, default=30) + + known_args, pipeline_args = parser.parse_known_args(argv) + return known_args, pipeline_args + + +def ensure_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + publisher.get_topic(request={"topic": full_topic_path}) + except NotFound: + 
publisher.create_topic(name=full_topic_path) + + try: + subscriber.get_subscription( + request={"subscription": full_subscription_path}) + except NotFound: + subscriber.create_subscription( + name=full_subscription_path, topic=full_topic_path) + + +def cleanup_pubsub_resources( + project: str, topic_path: str, subscription_path: str): + publisher = pubsub_v1.PublisherClient() + subscriber = pubsub_v1.SubscriberClient() + + topic_name = topic_path.split("/")[-1] + subscription_name = subscription_path.split("/")[-1] + + full_topic_path = publisher.topic_path(project, topic_name) + full_subscription_path = subscriber.subscription_path( + project, subscription_name) + + try: + subscriber.delete_subscription( + request={"subscription": full_subscription_path}) + logging.info(f"Deleted subscription: {subscription_name}") + except NotFound: + logging.info(f"Subscription already deleted: {subscription_name}") + + try: + publisher.delete_topic(request={"topic": full_topic_path}) + logging.info(f"Deleted topic: {topic_name}") + except NotFound: + logging.info(f"Topic already deleted: {topic_name}") + + +def override_or_add(args, flag, value): + if flag in args: + idx = args.index(flag) + args[idx + 1] = str(value) + else: + args.extend([flag, str(value)]) + + +# ============ Load pipeline ============ + + +def run_load_pipeline(known_args, pipeline_args): + """Reads GCS file with URIs and publishes them to Pub/Sub (for streaming).""" + # enforce smaller/CPU-only defaults for feeder + override_or_add(pipeline_args, '--device', 'CPU') + override_or_add(pipeline_args, '--num_workers', '5') + override_or_add(pipeline_args, '--max_num_workers', '10') + override_or_add( + pipeline_args, '--job_name', f"images-load-pubsub-{int(time.time())}") + override_or_add(pipeline_args, '--project', known_args.project) + pipeline_args = [ + arg for arg in pipeline_args if not arg.startswith("--experiments") + ] + + pipeline_options = PipelineOptions(pipeline_args) + pipeline = 
beam.Pipeline(options=pipeline_options) + + _ = ( + pipeline + | 'ReadGCSFile' >> beam.io.ReadFromText(known_args.input) + | 'FilterEmpty' >> beam.Filter(lambda line: line.strip()) + | 'ToBytes' >> beam.Map(lambda line: line.encode('utf-8')) + | 'ToPubSub' >> beam.io.WriteToPubSub(topic=known_args.pubsub_topic)) + return pipeline.run() + + +# ============ Main pipeline ============ + + +def run( + argv=None, save_main_session=True, test_pipeline=None) -> PipelineResult: + known_args, pipeline_args = parse_known_args(argv) + + if known_args.mode == 'streaming': + ensure_pubsub_resources( + project=known_args.project, + topic_path=known_args.pubsub_topic, + subscription_path=known_args.pubsub_subscription) + + # Start feeder thread that reads URIs from GCS and fills Pub/Sub. + # Delay is used to allow the main streaming pipeline workers to start + # and autoscale before the feeder pipeline begins publishing messages. + threading.Thread( + target=lambda: ( + time.sleep(known_args.feeder_start_delay_sec), run_load_pipeline( + known_args, pipeline_args)), + daemon=True).start() + + pipeline_options = PipelineOptions(pipeline_args) + pipeline_options.view_as(SetupOptions).save_main_session = save_main_session + pipeline_options.view_as(StandardOptions).streaming = ( + known_args.mode == 'streaming') + + device = 'cuda' if known_args.device.upper() == 'GPU' else 'cpu' + clip_score_normalize = (known_args.clip_score_normalize == 'true') + + blip_handler = BlipCaptionModelHandler( + model_name=known_args.blip_model_name, + device=device, + batch_size=int(known_args.blip_batch_size), + num_captions=int(known_args.num_captions), + max_new_tokens=int(known_args.max_new_tokens), + num_beams=int(known_args.num_beams), + ) + + clip_handler = ClipRankModelHandler( + model_name=known_args.clip_model_name, + device=device, + batch_size=int(known_args.clip_batch_size), + score_normalize=clip_score_normalize, + ) + + pipeline = test_pipeline or beam.Pipeline(options=pipeline_options) 
+ + if known_args.mode == 'batch': + pcoll = ( + pipeline + | 'ReadURIsBatch' >> beam.io.ReadFromText(known_args.input) + | 'FilterEmptyBatch' >> beam.Filter(lambda s: s.strip())) + else: + pcoll = ( + pipeline + | 'ReadFromPubSub' >> + beam.io.ReadFromPubSub(subscription=known_args.pubsub_subscription) + | 'DecodeUTF8' >> beam.Map(lambda x: x.decode('utf-8')) + | 'Window' >> beam.WindowInto( + window.FixedWindows(known_args.window_sec), + trigger=beam.trigger.AfterProcessingTime( + known_args.trigger_proc_time_sec), + accumulation_mode=beam.trigger.AccumulationMode.DISCARDING, + allowed_lateness=0)) + + keyed = (pcoll | 'MakeKey' >> beam.ParDo(MakeKeyDoFn())) + image_bytes = (keyed | 'ReadImageBytes' >> beam.ParDo(ReadImageBytesDoFn())) + images = (image_bytes | 'DecodeImage' >> beam.ParDo(DecodeImageDoFn())) + + # Stage 1: BLIP candidate generation + blip_out = ( + images + | 'RunInferenceBLIP' >> RunInference(KeyedModelHandler(blip_handler))) + + # Stage 2: CLIP ranking over candidates + clip_out = ( + blip_out + | 'RunInferenceCLIP' >> RunInference(KeyedModelHandler(clip_handler))) + + results = ( + clip_out + | 'PostProcess' >> beam.ParDo( + PostProcessDoFn( + blip_name=known_args.blip_model_name, + clip_name=known_args.clip_model_name))) + + method = ( + beam.io.WriteToBigQuery.Method.FILE_LOADS if known_args.mode == 'batch' + else beam.io.WriteToBigQuery.Method.STREAMING_INSERTS) + + if known_args.publish_to_big_query == 'true': + _ = ( + results + | 'WriteToBigQuery' >> beam.io.WriteToBigQuery( + known_args.output_table, + schema=( + 'image_id:STRING, blip_model:STRING, clip_model:STRING, ' + 'best_caption:STRING, best_score:FLOAT, ' + 'candidates:STRING, scores:STRING, ' + 'blip_ms:INT64, clip_ms:INT64, total_ms:INT64, infer_ms:INT64'), + write_disposition=beam.io.BigQueryDisposition.WRITE_APPEND, + create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED, + method=method)) + + result = pipeline.run() + result.wait_until_finish(duration=1800000) # 
30 min + try: + result.cancel() + except Exception: + pass + + if known_args.mode == 'streaming': + cleanup_pubsub_resources( + project=known_args.project, + topic_path=known_args.pubsub_topic, + subscription_path=known_args.pubsub_subscription) + Review Comment:  ### Robustness: Ensure Pub/Sub Resources are Cleaned Up If `pipeline.run()` or `result.wait_until_finish()` raises an exception (for example, because the job fails), execution never reaches the cancellation and cleanup code at the end of the `run` function, leaving orphaned Pub/Sub topics and subscriptions. Wrap the pipeline execution and waiting in a `try...finally` block so that cleanup runs on both the success and the failure path. ```suggestion try: result = pipeline.run() result.wait_until_finish(duration=1800000) # 30 min finally: try: result.cancel() except Exception: pass if known_args.mode == 'streaming': cleanup_pubsub_resources( project=known_args.project, topic_path=known_args.pubsub_topic, subscription_path=known_args.pubsub_subscription) ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
