This is an automated email from the ASF dual-hosted git repository.
vterentev pushed a commit to branch oss-image-cpu
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/oss-image-cpu by this push:
new 4c9eee02b97 Fix logits
4c9eee02b97 is described below
commit 4c9eee02b97c70d1365dd423ceb89c5b8e675b42
Author: Vitaly Terentyev <[email protected]>
AuthorDate: Thu Jan 29 16:54:12 2026 +0400
Fix logits
---
.../inference/pytorch_imagenet_rightfit.py | 25 +++++++++++++++++++---
1 file changed, 22 insertions(+), 3 deletions(-)
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py b/sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py
index 79e553c0652..c4be970829d 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_imagenet_rightfit.py
@@ -176,9 +176,28 @@ class PostProcessDoFn(beam.DoFn):
def process(self, kv: Tuple[str, PredictionResult]):
image_id, pred = kv
- logits = pred.inference[
- "logits"] # torch.Tensor [B, num_classes] or [num_classes]
- if isinstance(logits, torch.Tensor) and logits.ndim == 1:
+
+ # pred can be PredictionResult OR raw inference object.
+ inference_obj = pred.inference if hasattr(pred, "inference") else pred
+
+ # inference_obj can be dict {'logits': tensor} OR tensor directly.
+ if isinstance(inference_obj, dict):
+ logits = inference_obj.get("logits", None)
+ if logits is None:
+ # fallback: try first value if dict shape differs
+ try:
+ logits = next(iter(inference_obj.values()))
+ except Exception:
+ logits = None
+ else:
+ logits = inference_obj
+
+ if not isinstance(logits, torch.Tensor):
+ logging.warning("Unexpected logits type for %s: %s", image_id,
type(logits))
+ return
+
+ # Ensure shape [1, C]
+ if logits.ndim == 1:
logits = logits.unsqueeze(0)
probs = F.softmax(logits, dim=-1) # [B, C]