Amar3tto commented on code in PR #37186:
URL: https://github.com/apache/beam/pull/37186#discussion_r3201880541


##########
sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py:
##########
@@ -0,0 +1,690 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This pipeline performs image captioning using a multi-model approach:
+BLIP generates candidate captions, CLIP ranks them by image-text similarity.
+
+The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP
+caption generation in batches on GPU, then runs CLIP ranking in batches on GPU.
+Results are written to BigQuery.
+"""
+
+import argparse
+import io
+import json
+import logging
+import threading
+import time
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import apache_beam as beam
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.runners.runner import PipelineResult
+from apache_beam.transforms import window
+
+from google.cloud import pubsub_v1
+import torch
+import PIL.Image as PILImage
+
+# ============ Utility ============
+
+
+def now_millis() -> int:
+  return int(time.time() * 1000)
+
+
+def read_gcs_file_lines(gcs_path: str) -> Iterable[str]:
+  """Reads text lines from a GCS file."""
+  with FileSystems.open(gcs_path) as f:
+    for line in f.read().decode("utf-8").splitlines():
+      yield line.strip()
+
+
+def load_image_from_uri(uri: str) -> bytes:
+  with FileSystems.open(uri) as f:
+    return f.read()
+
+
+def sha1_hex(s: str) -> str:
+  import hashlib
+  return hashlib.sha1(s.encode("utf-8")).hexdigest()
+
+
+def decode_pil(image_bytes: bytes) -> PILImage.Image:
+  with PILImage.open(io.BytesIO(image_bytes)) as img:
+    img = img.convert("RGB")
+    img.load()
+    return img
+
+
+# ============ DoFns ============
+
+
+class RateLimitDoFn(beam.DoFn):
+  def __init__(self, rate_per_sec: float):
+    self.delay = 1.0 / rate_per_sec
+
+  def process(self, element):
+    time.sleep(self.delay)
+    yield element
+
+
+class MakeKeyDoFn(beam.DoFn):
+  """Produce (image_id, uri) where image_id is stable for dedup and keys."""
+  def process(self, element: str):
+    uri = element
+    image_id = sha1_hex(uri)
+    yield image_id, uri
+
+
+class ReadImageBytesDoFn(beam.DoFn):
+  """Turn (image_id, uri) -> (image_id, dict(image_bytes, uri))."""
+  def process(self, kv: Tuple[str, str]):
+    image_id, uri = kv
+    try:
+      b = load_image_from_uri(uri)
+      yield image_id, {"image_bytes": b, "uri": uri}
+    except Exception as e:
+      logging.warning("Failed to read image %s (%s): %s", image_id, uri, e)
+      return
+
+
+class PostProcessDoFn(beam.DoFn):
+  """Final PredictionResult -> row for BigQuery."""
+  def __init__(self, blip_name: str, clip_name: str):
+    self.blip_name = blip_name
+    self.clip_name = clip_name
+
+  def process(self, kv: Tuple[str, PredictionResult]):
+    image_id, pred = kv
+    if hasattr(pred, "inference"):
+      inf = pred.inference or {}
+    else:
+      inf = pred
+    # Expected inference fields from CLIP handler:
+    # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms
+    best_caption = inf.get("best_caption", "")
+    best_score = inf.get("best_score", None)
+    candidates = inf.get("candidates", [])
+    scores = inf.get("scores", [])
+    blip_ms = inf.get("blip_ms", None)
+    clip_ms = inf.get("clip_ms", None)
+    total_ms = inf.get("total_ms", None)
+
+    yield {
+        "image_id": image_id,
+        "blip_model": self.blip_name,
+        "clip_model": self.clip_name,
+        "best_caption": best_caption,
+        "best_score": float(best_score) if best_score is not None else None,
+        "candidates": json.dumps(candidates),
+        "scores": json.dumps(scores),
+        "blip_ms": int(blip_ms) if blip_ms is not None else None,
+        "clip_ms": int(clip_ms) if clip_ms is not None else None,
+        "total_ms": int(total_ms) if total_ms is not None else None,
+        "infer_ms": now_millis(),
+    }
+
+
+# ============ Model Handlers ============
+
+
+class BlipCaptionModelHandler(ModelHandler):
+  def __init__(
+      self,
+      model_name: str,
+      device: str,
+      batch_size: int,
+      num_captions: int,
+      max_new_tokens: int,
+      num_beams: int):
+    self.model_name = model_name
+    self.device = device
+    self.batch_size = batch_size
+    self.num_captions = num_captions
+    self.max_new_tokens = max_new_tokens
+    self.num_beams = num_beams
+
+    self._model = None
+    self._processor = None
+
+  def load_model(self):
+    from transformers import BlipForConditionalGeneration, BlipProcessor
+    self._processor = BlipProcessor.from_pretrained(self.model_name)
+    self._model = BlipForConditionalGeneration.from_pretrained(self.model_name)
+    self._model.eval()
+    self._model.to(self.device)
+    return self._model
+
+  def batch_elements_kwargs(self):
+    return {"max_batch_size": self.batch_size}
+
+  def run_inference(
+      self, batch: List[Dict[str, Any]], model, inference_args=None):
+
+    if model is not None:

Review Comment:
   Fixed



##########
sdks/python/apache_beam/examples/inference/pytorch_image_captioning.py:
##########
@@ -0,0 +1,690 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""This pipeline performs image captioning using a multi-model approach:
+BLIP generates candidate captions, CLIP ranks them by image-text similarity.
+
+The pipeline reads image URIs from a GCS input file, decodes images, runs BLIP
+caption generation in batches on GPU, then runs CLIP ranking in batches on GPU.
+Results are written to BigQuery.
+"""
+
+import argparse
+import io
+import json
+import logging
+import threading
+import time
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import apache_beam as beam
+from apache_beam.io.filesystems import FileSystems
+from apache_beam.ml.inference.base import KeyedModelHandler
+from apache_beam.ml.inference.base import ModelHandler
+from apache_beam.ml.inference.base import PredictionResult
+from apache_beam.ml.inference.base import RunInference
+from apache_beam.options.pipeline_options import PipelineOptions
+from apache_beam.options.pipeline_options import SetupOptions
+from apache_beam.options.pipeline_options import StandardOptions
+from apache_beam.runners.runner import PipelineResult
+from apache_beam.transforms import window
+
+from google.cloud import pubsub_v1
+import torch
+import PIL.Image as PILImage
+
+# ============ Utility ============
+
+
+def now_millis() -> int:
+  return int(time.time() * 1000)
+
+
+def read_gcs_file_lines(gcs_path: str) -> Iterable[str]:
+  """Reads text lines from a GCS file."""
+  with FileSystems.open(gcs_path) as f:
+    for line in f.read().decode("utf-8").splitlines():
+      yield line.strip()
+
+
+def load_image_from_uri(uri: str) -> bytes:
+  with FileSystems.open(uri) as f:
+    return f.read()
+
+
+def sha1_hex(s: str) -> str:
+  import hashlib
+  return hashlib.sha1(s.encode("utf-8")).hexdigest()
+
+
+def decode_pil(image_bytes: bytes) -> PILImage.Image:
+  with PILImage.open(io.BytesIO(image_bytes)) as img:
+    img = img.convert("RGB")
+    img.load()
+    return img
+
+
+# ============ DoFns ============
+
+
+class RateLimitDoFn(beam.DoFn):
+  def __init__(self, rate_per_sec: float):
+    self.delay = 1.0 / rate_per_sec
+
+  def process(self, element):
+    time.sleep(self.delay)
+    yield element
+
+
+class MakeKeyDoFn(beam.DoFn):
+  """Produce (image_id, uri) where image_id is stable for dedup and keys."""
+  def process(self, element: str):
+    uri = element
+    image_id = sha1_hex(uri)
+    yield image_id, uri
+
+
+class ReadImageBytesDoFn(beam.DoFn):
+  """Turn (image_id, uri) -> (image_id, dict(image_bytes, uri))."""
+  def process(self, kv: Tuple[str, str]):
+    image_id, uri = kv
+    try:
+      b = load_image_from_uri(uri)
+      yield image_id, {"image_bytes": b, "uri": uri}
+    except Exception as e:
+      logging.warning("Failed to read image %s (%s): %s", image_id, uri, e)
+      return
+
+
+class PostProcessDoFn(beam.DoFn):
+  """Final PredictionResult -> row for BigQuery."""
+  def __init__(self, blip_name: str, clip_name: str):
+    self.blip_name = blip_name
+    self.clip_name = clip_name
+
+  def process(self, kv: Tuple[str, PredictionResult]):
+    image_id, pred = kv
+    if hasattr(pred, "inference"):
+      inf = pred.inference or {}
+    else:
+      inf = pred
+    # Expected inference fields from CLIP handler:
+    # best_caption, best_score, candidates, scores, blip_ms, clip_ms, total_ms
+    best_caption = inf.get("best_caption", "")
+    best_score = inf.get("best_score", None)
+    candidates = inf.get("candidates", [])
+    scores = inf.get("scores", [])
+    blip_ms = inf.get("blip_ms", None)
+    clip_ms = inf.get("clip_ms", None)
+    total_ms = inf.get("total_ms", None)
+
+    yield {
+        "image_id": image_id,
+        "blip_model": self.blip_name,
+        "clip_model": self.clip_name,
+        "best_caption": best_caption,
+        "best_score": float(best_score) if best_score is not None else None,
+        "candidates": json.dumps(candidates),
+        "scores": json.dumps(scores),
+        "blip_ms": int(blip_ms) if blip_ms is not None else None,
+        "clip_ms": int(clip_ms) if clip_ms is not None else None,
+        "total_ms": int(total_ms) if total_ms is not None else None,
+        "infer_ms": now_millis(),
+    }
+
+
+# ============ Model Handlers ============
+
+
+class BlipCaptionModelHandler(ModelHandler):
+  def __init__(
+      self,
+      model_name: str,
+      device: str,
+      batch_size: int,
+      num_captions: int,
+      max_new_tokens: int,
+      num_beams: int):
+    self.model_name = model_name
+    self.device = device
+    self.batch_size = batch_size
+    self.num_captions = num_captions
+    self.max_new_tokens = max_new_tokens
+    self.num_beams = num_beams
+
+    self._model = None
+    self._processor = None
+
+  def load_model(self):
+    from transformers import BlipForConditionalGeneration, BlipProcessor
+    self._processor = BlipProcessor.from_pretrained(self.model_name)
+    self._model = BlipForConditionalGeneration.from_pretrained(self.model_name)
+    self._model.eval()
+    self._model.to(self.device)
+    return self._model
+
+  def batch_elements_kwargs(self):
+    return {"max_batch_size": self.batch_size}
+
+  def run_inference(
+      self, batch: List[Dict[str, Any]], model, inference_args=None):
+
+    if model is not None:
+      self._model = model
+      self._model.to(self.device)
+      self._model.eval()
+    if self._processor is None:
+      from transformers import BlipProcessor
+      self._processor = BlipProcessor.from_pretrained(self.model_name)

Review Comment:
   Fixed



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to