This is an automated email from the ASF dual-hosted git repository.
vterentev pushed a commit to branch oss-image-cpu
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/oss-image-cpu by this push:
new 9dd2ef927bc Fix inference_fn
9dd2ef927bc is described below
commit 9dd2ef927bc6a83f77f05877278bc5d6aab2fb53
Author: Vitaly Terentyev <[email protected]>
AuthorDate: Thu Jan 22 13:06:07 2026 +0400
Fix inference_fn
---
.../inference/pytorch_image_object_detection.py | 27 ++++++++++++++++------
1 file changed, 20 insertions(+), 7 deletions(-)
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py b/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py
index 0bb5f53e5b0..cb66e0f9dba 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_image_object_detection.py
@@ -31,12 +31,13 @@ import io
import json
import logging
import time
-from typing import Iterable
-from typing import Optional
-from typing import Tuple
from typing import Any
from typing import Dict
+from typing import Iterable
from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
import apache_beam as beam
from apache_beam.io.filesystems import FileSystems
@@ -135,12 +136,24 @@ class DecodePreprocessDoFn(beam.DoFn):
def _torchvision_detection_inference_fn(
- model, batch: List[torch.Tensor], device: str) -> List[Dict[str, Any]]:
- """Custom inference for TorchVision detection models.
-
- TorchVision detection models expect: List[Tensor] (each: CHW float [0..1]).
+ batch: Sequence[torch.Tensor],
+ model: torch.nn.Module,
+ device: torch.device,
+ inference_args: Optional[dict[str, Any]] = None,
+ model_id: Optional[str] = None,
+) -> List[Dict[str, Any]]:
+ """Inference function for TorchVision detection models.
+
+ TorchVision detection models expect List[Tensor] where each tensor is:
+ - shape: [3, H, W]
+ - dtype: float32
+ - values: [0..1]
"""
+ del inference_args
+ del model_id
+
with torch.no_grad():
+ # Ensure tensors are on device
inputs = []
for t in batch:
if isinstance(t, torch.Tensor):