damccorm opened a new issue, #37093: URL: https://github.com/apache/beam/issues/37093
### What happened?

I noticed that while vllm [expects inference args to be available to it](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/vllm_inference.py#L217), in practice this is not possible because it doesn't override the base ModelHandler `validate_inference_args` function, which doesn't allow any inference args. See: https://github.com/apache/beam/blob/master/sdks/python/apache_beam/ml/inference/base.py#L215

Given this, I did an audit of all our model handlers and found that most are not handling this correctly. Here are the results:

**Currently handling inference args correctly by overriding validate_inference_args:**

- vertex_ai_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py#L206) among other places. Never dropped silently. validate_inference_args [is overridden](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/vertex_ai_inference.py#L210) to silently pass
- tensorflow_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/tensorflow_inference.py#L84) among other places. Never dropped silently. validate_inference_args [is overridden](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/tensorflow_inference.py#L363) to silently pass in multiple places
- pytorch_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L152) among other places. Never dropped silently. validate_inference_args [is overridden](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/pytorch_inference.py#L345) to silently pass in multiple places

**Currently trying to consume inference args, but not overriding validate_inference_args:**

- gemini_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/gemini_inference.py#L97) and [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/gemini_inference.py#L74). Never dropped silently. validate_inference_args not overridden
- huggingface_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L186) and [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/huggingface_inference.py#L206) among other places. Never dropped silently. validate_inference_args not overridden
- onnx_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/onnx_inference.py#L152). Intentionally added [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/onnx_inference.py#L48). Never dropped silently. validate_inference_args not overridden
- vllm_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/vllm_inference.py#L217) among other places. Never dropped silently. validate_inference_args not overridden
- xgboost_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/xgboost_inference.py#L66) among other places. Never dropped silently. validate_inference_args not overridden

**Currently not handling inference args, but should:**

- sklearn_inference.py - passed through [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/sklearn_inference.py#L310) among other places. Silently dropped [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/sklearn_inference.py#L75) and [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/sklearn_inference.py#L203) (probably wrongly). validate_inference_args not overridden.

**Currently not handling inference args correctly:**

- tensorrt_inference.py - silently dropped [here](https://github.com/apache/beam/blob/576a80232f8edca5e732eea97b90482209c8a460/sdks/python/apache_beam/ml/inference/tensorrt_inference.py#L183). This may be correct; inference args are hard to map to meaning here.

-------------------------

The underlying rationale for disallowing inference args (`Because most frameworks do not need extra arguments in their predict() call, the default behavior is to error out if inference_args are present.`) clearly does not hold. Given that, we should:

1) Change the base ModelHandler behavior to allow inference args
2) Add the validation to model handlers that need it. At this time only tensorrt_inference.py needs it, and sklearn_inference.py needs to be updated to correctly consume the inference args.
3) Mention this in CHANGES.md since it is a behavior change (though not breaking)

### Issue Priority

Priority: 2 (default / most bugs should be filed as P2)

### Issue Components

- [x] Component: Python SDK
- [ ] Component: Java SDK
- [ ] Component: Go SDK
- [ ] Component: Typescript SDK
- [ ] Component: IO connector
- [ ] Component: Beam YAML
- [ ] Component: Beam examples
- [ ] Component: Beam playground
- [ ] Component: Beam katas
- [ ] Component: Website
- [ ] Component: Infrastructure
- [ ] Component: Spark Runner
- [ ] Component: Flink Runner
- [ ] Component: Samza Runner
- [ ] Component: Twister2 Runner
- [ ] Component: Hazelcast Jet Runner
- [ ] Component: Google Cloud Dataflow Runner
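
To make the failure mode from the audit concrete, here is a minimal sketch of the pattern. It does not import Beam: `BaseHandlerSketch`, `ConsumingHandler`, `FixedConsumingHandler`, `run_with_validation`, and `scale` are hypothetical stand-ins, and the assumption that validation runs before inference is mine. Only the method name `validate_inference_args` and the "error out if inference_args are present" default come from the issue text.

```python
from typing import Any, Dict, Optional


class BaseHandlerSketch:
  """Hypothetical stand-in for the base ModelHandler; not Beam's real class."""
  def validate_inference_args(
      self, inference_args: Optional[Dict[str, Any]]) -> None:
    # Default described in this issue: any inference args are rejected.
    if inference_args:
      raise ValueError('inference_args are not accepted by this handler')

  def run_inference(self, batch, model, inference_args=None):
    raise NotImplementedError


class ConsumingHandler(BaseHandlerSketch):
  """Consumes inference args in run_inference but keeps the base validation."""
  def run_inference(self, batch, model, inference_args=None):
    kwargs = inference_args or {}
    return [model(example, **kwargs) for example in batch]


class FixedConsumingHandler(ConsumingHandler):
  """Same handler with validation overridden to silently pass."""
  def validate_inference_args(self, inference_args):
    pass


def run_with_validation(handler, batch, model, inference_args=None):
  # Assumption for this sketch: validation happens before inference, so a
  # handler that never overrides it cannot receive inference args at all.
  handler.validate_inference_args(inference_args)
  return handler.run_inference(batch, model, inference_args)


def scale(x, factor=1):
  # Toy "model" whose keyword argument plays the role of an inference arg.
  return x * factor
```

With these stand-ins, `run_with_validation(ConsumingHandler(), [1, 2], scale, {'factor': 3})` raises `ValueError` even though `run_inference` would have used the argument, while the same call with `FixedConsumingHandler()` goes through. Flipping the base default to allow args, as proposed in step 1, would make the second class unnecessary and leave explicit rejection to handlers such as tensorrt_inference.py.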
