Hi, I've copied and pasted four scripts. The RNN is defined in `rnn.py` and the RNN-T in
`model_rnnt.py`.
I'm trying to import the RNN-T model into TVM; please look at the first script.
In the function `get_rnnt_model` I load the pre-trained RNN-T model, and in
`rnnt_model_to_tvm_mod` I try to convert it into a TVM module.
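```
import torch
from tvm import relay
from model_rnnt import RNNT  # RNN-T definition from model_rnnt.py

def get_rnnt_model(featurizer_config, model_definition, ctc_vocab, ckpt):
    # Build the RNN-T network and load the pre-trained weights on the CPU.
    model = RNNT(
        feature_config=featurizer_config,
        rnnt=model_definition['rnnt'],
        num_classes=len(ctc_vocab)
    )
    checkpoint = torch.load(ckpt, map_location="cpu")
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    model.eval()
    return model

def rnnt_model_to_tvm_mod(model):
    # Assumed encoder inputs: features laid out as (time, batch, features)
    # plus one int64 sequence length per batch element.
    input_shape = (316, 1, 240)
    len_shape = (1,)  # (316) without a comma is the int 316, not a tuple
    t_audio_signal_e = torch.randn(input_shape)
    t_a_sig_length_e = torch.tensor([input_shape[0]], dtype=torch.int64)
    # Trace only the encoder; from_pytorch expects a TorchScript module.
    traced_encoder = torch.jit.trace(
        model.encoder, (t_audio_signal_e, t_a_sig_length_e)).eval()
    # from_pytorch expects a list with one entry per traced input, not None.
    # The names are arbitrary labels; the (shape, dtype) form needs a recent TVM.
    input_infos = [("audio_signal", (input_shape, "float32")),
                   ("a_sig_length", (len_shape, "int64"))]
    mod, params = relay.frontend.from_pytorch(traced_encoder, input_infos)
    mod = relay.transform.RemoveUnusedFunctions()(mod)
    return mod, params
```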
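The `(316)` versus `(316,)` difference is easy to miss, so here is a small pure-Python sketch of building the per-input list that `relay.frontend.from_pytorch` takes as its second argument. It assumes the encoder takes a float32 `(time, batch, features)` tensor followed by an int64 `(batch,)` lengths tensor; the helper `make_input_infos` and the input names are hypothetical, since TVM only uses the names to label the Relay inputs.

```python
def make_input_infos(seq_len, batch, feat_dim):
    """Hypothetical helper: one (name, (shape, dtype)) entry per encoder input.

    Assumes the encoder takes (audio features, lengths), in that order.
    """
    audio_shape = (seq_len, batch, feat_dim)  # (time, batch, features)
    length_shape = (batch,)  # the trailing comma makes this a 1-tuple
    return [
        ("audio_signal", (audio_shape, "float32")),
        ("a_sig_length", (length_shape, "int64")),
    ]


# Parentheses alone do not create a tuple; the comma does.
assert (316) == 316 and not isinstance((316), tuple)
assert (316,) == tuple([316])

infos = make_input_infos(316, 1, 240)
print(infos)
# → [('audio_signal', ((316, 1, 240), 'float32')), ('a_sig_length', ((1,), 'int64'))]
```

The resulting list would be passed positionally as the second argument of `from_pytorch`, with one entry per input of the traced encoder, in the same order as the example tensors given to `torch.jit.trace`.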
---
[Visit
Topic](https://discuss.tvm.apache.org/t/import-rnn-t-pytorch-model-into-tvm/7874/9)
to respond.