ryanthompson591 commented on code in PR #21806:
URL: https://github.com/apache/beam/pull/21806#discussion_r897131988
##########
sdks/python/apache_beam/ml/inference/pytorch_inference_test.py:
##########
@@ -237,10 +237,8 @@ def test_run_inference_kwargs_prediction_params(self):
inference_runner = TestPytorchModelHandlerForInferenceOnly(
torch.device('cpu'))
predictions = inference_runner.run_inference(
- batch=KWARGS_TORCH_EXAMPLES,
- model=model,
- prediction_params=prediction_params)
- for actual, expected in zip(predictions, KWARGS_TORCH_PREDICTIONS):
+ batch=KEYED_TORCH_EXAMPLES, model=model, extra_kwargs=extra_kwargs)
Review Comment:
ok, I've been thinking about this:
I was playing around with ways to make arguments specific to run_inference,
and I think there are only three ways: what you have done, anonymous
args, or an if statement:

    if inference_args:
      model_handler.run_inference(model, batch, inference_args)
    else:
      model_handler.run_inference(model, batch)

I'm not sure which I prefer now that I'm looking at it.
The if statement has the advantage of allowing clients that don't expect
this argument to pass (or fail) without modifications.
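To make the trade-off concrete, here is a minimal sketch of the if-statement dispatch under discussion. The class and function names (`LegacyHandler`, `KwargsHandler`, `dispatch`) are illustrative placeholders, not Beam's actual API:

```python
class LegacyHandler:
  """A hypothetical handler written before inference_args existed."""

  def run_inference(self, model, batch):
    return [model(x) for x in batch]


class KwargsHandler:
  """A hypothetical handler that accepts optional per-call arguments."""

  def run_inference(self, model, batch, inference_args=None):
    inference_args = inference_args or {}
    return [model(x, **inference_args) for x in batch]


def dispatch(handler, model, batch, inference_args=None):
  # The if-statement approach: only pass inference_args when present,
  # so handlers that never expected the argument keep working unchanged.
  if inference_args:
    return handler.run_inference(model, batch, inference_args)
  return handler.run_inference(model, batch)


model = lambda x, scale=1: x * scale
print(dispatch(LegacyHandler(), model, [1, 2, 3]))
print(dispatch(KwargsHandler(), model, [1, 2, 3], {'scale': 2}))
```

The legacy handler works through `dispatch` without modification as long as callers omit `inference_args`; passing it to a handler whose signature lacks the parameter would raise a `TypeError`, which is the "fail without modifications" case noted above.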
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]