This is an automated email from the ASF dual-hosted git repository.

vterentev pushed a commit to branch oss-image-cpu
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/oss-image-cpu by this push:
     new 4c9eee02b97 Fix logits
4c9eee02b97 is described below

commit 4c9eee02b97c70d1365dd423ceb89c5b8e675b42
Author: Vitaly Terentyev <[email protected]>
AuthorDate: Thu Jan 29 16:54:12 2026 +0400

    Fix logits
---
 .../inference/pytorch_imagenet_rightfit.py         | 25 +++++++++++++++++++---
 1 file changed, 22 insertions(+), 3 deletions(-)

diff --git a/sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py b/sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py
index 79e553c0652..c4be970829d 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py
@@ -176,9 +176,28 @@ class PostProcessDoFn(beam.DoFn):
 
   def process(self, kv: Tuple[str, PredictionResult]):
     image_id, pred = kv
-    logits = pred.inference[
-        "logits"]  # torch.Tensor [B, num_classes] or [num_classes]
-    if isinstance(logits, torch.Tensor) and logits.ndim == 1:
+
+    # pred can be PredictionResult OR raw inference object.
+    inference_obj = pred.inference if hasattr(pred, "inference") else pred
+
+    # inference_obj can be dict {'logits': tensor} OR tensor directly.
+    if isinstance(inference_obj, dict):
+      logits = inference_obj.get("logits", None)
+      if logits is None:
+        # fallback: try first value if dict shape differs
+        try:
+          logits = next(iter(inference_obj.values()))
+        except Exception:
+          logits = None
+    else:
+      logits = inference_obj
+
+    if not isinstance(logits, torch.Tensor):
+      logging.warning("Unexpected logits type for %s: %s", image_id, type(logits))
+      return
+
+    # Ensure shape [1, C]
+    if logits.ndim == 1:
       logits = logits.unsqueeze(0)
 
     probs = F.softmax(logits, dim=-1)  # [B, C]

Reply via email to