This is an automated email from the ASF dual-hosted git repository.
Amar3tto pushed a commit to branch test-inference
in repository https://gitbox.apache.org/repos/asf/beam.git
The following commit(s) were added to refs/heads/test-inference by this push:
new 4d2cfc68972 Fix config
4d2cfc68972 is described below
commit 4d2cfc689727fbebcb5f221a4e5c674dc39a9591
Author: Vitaly Terentyev <[email protected]>
AuthorDate: Wed May 6 23:02:00 2026 +0400
Fix config
---
sdks/python/apache_beam/examples/inference/pytorch_sentiment.py | 8 +++++++-
1 file changed, 7 insertions(+), 1 deletion(-)
diff --git a/sdks/python/apache_beam/examples/inference/pytorch_sentiment.py b/sdks/python/apache_beam/examples/inference/pytorch_sentiment.py
index 3bb36930a04..71669522674 100644
--- a/sdks/python/apache_beam/examples/inference/pytorch_sentiment.py
+++ b/sdks/python/apache_beam/examples/inference/pytorch_sentiment.py
@@ -234,9 +234,15 @@ def run(
method = beam.io.WriteToBigQuery.Method.STREAMING_INSERTS
pipeline_options.view_as(StandardOptions).streaming = True
+ model_config = DistilBertConfig.from_pretrained(
+ known_args.model_path, num_labels=2)
+ # Some transformers versions may not initialize this field on config objects.
+ if not hasattr(model_config, 'pruned_heads'):
+ model_config.pruned_heads = {}
+
model_handler = PytorchModelHandlerKeyedTensor(
model_class=DistilBertForSequenceClassification,
- model_params={'config': DistilBertConfig(num_labels=2)},
+ model_params={'config': model_config},
state_dict_path=known_args.model_state_dict_path,
device='GPU')