damccorm opened a new issue, #37093:
URL: https://github.com/apache/beam/issues/37093

   ### What happened?
   
   I noticed that while vllm [expects inference args to be available to 
it](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/vllm_inference.py#L217),
 in practice this is not possible because it doesn't override the base 
ModelHandler `validate_inference_args` function, which doesn't allow any 
inference args. See:
   
   
https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/base.py#L215
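   
   A minimal sketch of the failure mode, using hypothetical stand-in classes rather than the real Beam ones (`BaseHandler`, `ConsumingHandler`, `FixedHandler`, `run`, and `scale` are all invented for illustration). It assumes validation runs before inference, so a handler that reads inference args but inherits the rejecting validator can never receive them:

```python
from typing import Any, Dict, Optional


class BaseHandler:
  """Hypothetical stand-in for the base ModelHandler (not the Beam class)."""

  def validate_inference_args(
      self, inference_args: Optional[Dict[str, Any]]) -> None:
    # Mirrors the default described above: non-empty args are rejected.
    if inference_args:
      raise ValueError(
          'inference_args were provided, but this handler rejects them.')

  def run_inference(self, batch, model, inference_args=None):
    raise NotImplementedError


class ConsumingHandler(BaseHandler):
  """Reads inference_args, but inherits the rejecting validator."""

  def run_inference(self, batch, model, inference_args=None):
    inference_args = inference_args or {}
    return [model(x, **inference_args) for x in batch]


class FixedHandler(ConsumingHandler):
  """Same handler with the validator overridden to accept args."""

  def validate_inference_args(self, inference_args):
    # Accept anything; run_inference forwards the args to the model.
    pass


def run(handler, batch, model, inference_args=None):
  """Hypothetical driver: validate first, then run inference."""
  handler.validate_inference_args(inference_args)
  return handler.run_inference(batch, model, inference_args)


def scale(x, factor=1):
  """Toy model whose keyword argument plays the role of an inference arg."""
  return x * factor
```

   With these stand-ins, `run(ConsumingHandler(), [1, 2], scale, {'factor': 3})` raises `ValueError` before the model is ever called, while the same call with `FixedHandler()` reaches the model.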
   
   Given this, I did an audit of all our model handlers and found that most are 
not handling this correctly. Here are the results:
   
   **Currently handling inference args correctly by overriding 
validate_inference_args:**
   
   - vertex_ai_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py#L206)
 among other places. Never dropped silently. validate_inference_args [is 
overridden](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py#L210)
 to silently pass
   - tensorflow_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/tensorflow_inference.py#L84)
 among other places. Never dropped silently. validate_inference_args [is 
overridden](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/tensorflow_inference.py#L363)
 to silently pass in multiple places
   - pytorch_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L152)
 among other places. Never dropped silently. validate_inference_args [is 
overridden](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L345)
 to silently pass in multiple places
   
   **Currently trying to consume inference args, but not overriding 
validate_inference_args:**
   
   - gemini_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/gemini_inference.py#L97)
 and 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/gemini_inference.py#L74).
 Never dropped silently. validate_inference_args not overridden
   - huggingface_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L186)
 and 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L206)
 among other places. Never dropped silently. validate_inference_args not 
overridden
   - onnx_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/onnx_inference.py#L152).
 Intentionally added to 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/onnx_inference.py#L48).
 Never dropped silently. validate_inference_args not overridden
   - vllm_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/vllm_inference.py#L217)
 among other places. Never dropped silently. validate_inference_args not 
overridden
   - xgboost_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/xgboost_inference.py#L66)
 among other places. Never dropped silently. validate_inference_args not 
overridden
   
   **Currently not handling inference args, but should:**
   
   - sklearn_inference.py - passed through 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/sklearn_inference.py#L310)
 among other places. Silently dropped 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/sklearn_inference.py#L75)
 and 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/sklearn_inference.py#L203)
 (probably wrongly). validate_inference_args not overridden.
   
   **Currently not handling inference args correctly:**
   
   - tensorrt_inference.py - silently dropped 
[here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/tensorrt_inference.py#L183).
 This may be correct, since inference args are hard to map to anything meaningful here.
   
   
   -------------------------
   
   The underlying rationale for disallowing inference args (`Because most 
frameworks do not need extra arguments in their predict() call, the default 
behavior is to error out if inference_args are present.`) clearly does not 
actually hold. Given that, we should:
   
   1) Change the base ModelHandler behavior to allow inference args
   2) Add the validation to model handlers that need it. At this time only 
tensorrt_inference.py needs it, and sklearn_inference.py needs to be updated to correctly 
consume the inference args.
   3) Mention this in CHANGES.md since it is a behavior change (though not 
breaking)
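   
   The proposal above could look roughly like the following. This is only a sketch under stated assumptions, not the actual Beam change: `PermissiveBaseHandler`, `NoArgsHandler`, `ForwardingHandler`, and `OffsetModel` are hypothetical names, and the `offset` keyword is an invented example of an inference arg.

```python
from typing import Any, Dict, Optional


class PermissiveBaseHandler:
  """Hypothetical base handler: inference args are allowed by default."""

  def validate_inference_args(
      self, inference_args: Optional[Dict[str, Any]]) -> None:
    # Step 1 of the proposal: the default no longer raises.
    pass


class NoArgsHandler(PermissiveBaseHandler):
  """Stand-in for a handler that cannot use extra args (the TensorRT case)."""

  def validate_inference_args(self, inference_args):
    # Step 2: only handlers that need the check opt back in to it.
    if inference_args:
      raise ValueError('This handler does not accept inference_args.')


class ForwardingHandler(PermissiveBaseHandler):
  """Stand-in for a handler that forwards args (the sklearn case)."""

  def run_inference(self, batch, model, inference_args=None):
    # Forward the args instead of silently dropping them.
    return model.predict(batch, **(inference_args or {}))


class OffsetModel:
  """Toy model; `offset` is an invented keyword used only for illustration."""

  def predict(self, batch, offset=0):
    return [x + offset for x in batch]
```

   Under this layout a handler that accepts inference args needs no override at all, and rejecting them becomes an explicit per-handler decision.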
   
   ### Issue Priority
   
   Priority: 2 (default / most bugs should be filed as P2)
   
   ### Issue Components
   
   - [x] Component: Python SDK
   - [ ] Component: Java SDK
   - [ ] Component: Go SDK
   - [ ] Component: Typescript SDK
   - [ ] Component: IO connector
   - [ ] Component: Beam YAML
   - [ ] Component: Beam examples
   - [ ] Component: Beam playground
   - [ ] Component: Beam katas
   - [ ] Component: Website
   - [ ] Component: Infrastructure
   - [ ] Component: Spark Runner
   - [ ] Component: Flink Runner
   - [ ] Component: Samza Runner
   - [ ] Component: Twister2 Runner
   - [ ] Component: Hazelcast Jet Runner
   - [ ] Component: Google Cloud Dataflow Runner

