model_rnnt.py
```python
# Copyright (c) 2019, Myrtle Software Limited. All rights reserved.
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch

from rnn import rnn, StackTime

class RNNT(torch.nn.Module):
"""A Recurrent Neural Network Transducer (RNN-T).
Args:
in_features: Number of input features per step per batch.
vocab_size: Number of output symbols (inc blank).
forget_gate_bias: Total initialized value of the bias used in the
forget gate. Set to None to use PyTorch's default initialisation.
(See: http://proceedings.mlr.press/v37/jozefowicz15.pdf)
batch_norm: Use batch normalization in encoder and prediction network
if true.
encoder_n_hidden: Internal hidden unit size of the encoder.
encoder_rnn_layers: Encoder number of layers.
pred_n_hidden: Internal hidden unit size of the prediction network.
pred_rnn_layers: Prediction network number of layers.
joint_n_hidden: Internal hidden unit size of the joint network.
rnn_type: string. Type of rnn in SUPPORTED_RNNS.
"""
def __init__(self, rnnt=None, num_classes=1, **kwargs):
super().__init__()
if kwargs.get("no_featurizer", False):
in_features = kwargs.get("in_features")
else:
feat_config = kwargs.get("feature_config")
# This may be useful in the future, for MLPerf
# configuration.
in_features = feat_config['features'] * \
feat_config.get("frame_splicing", 1)
self._pred_n_hidden = rnnt['pred_n_hidden']
self.encoder_n_hidden = rnnt["encoder_n_hidden"]
self.encoder_pre_rnn_layers = rnnt["encoder_pre_rnn_layers"]
self.encoder_post_rnn_layers = rnnt["encoder_post_rnn_layers"]
self.pred_n_hidden = rnnt["pred_n_hidden"]
self.pred_rnn_layers = rnnt["pred_rnn_layers"]
self.encoder = Encoder(in_features,
rnnt["encoder_n_hidden"],
rnnt["encoder_pre_rnn_layers"],
rnnt["encoder_post_rnn_layers"],
rnnt["forget_gate_bias"],
None if "norm" not in rnnt else rnnt["norm"],
rnnt["rnn_type"],
rnnt["encoder_stack_time_factor"],
rnnt["dropout"],
)
self.prediction = self._predict(
num_classes,
rnnt["pred_n_hidden"],
rnnt["pred_rnn_layers"],
rnnt["forget_gate_bias"],
None if "norm" not in rnnt else rnnt["norm"],
rnnt["rnn_type"],
rnnt["dropout"],
)
self.joint_net = self._joint_net(
num_classes,
rnnt["pred_n_hidden"],
rnnt["encoder_n_hidden"],
rnnt["joint_n_hidden"],
rnnt["dropout"],
)
def _predict(self, vocab_size, pred_n_hidden, pred_rnn_layers,
forget_gate_bias, norm, rnn_type, dropout):
        layers = torch.nn.ModuleDict({
            # vocab_size - 1: the blank symbol is never fed to the
            # prediction network, so it gets no embedding.
            "embed": torch.nn.Embedding(vocab_size - 1, pred_n_hidden),
"dec_rnn": rnn(
rnn=rnn_type,
input_size=pred_n_hidden,
hidden_size=pred_n_hidden,
num_layers=pred_rnn_layers,
norm=norm,
forget_gate_bias=forget_gate_bias,
dropout=dropout,
),
})
return layers
def _joint_net(self, vocab_size, pred_n_hidden, enc_n_hidden,
joint_n_hidden, dropout):
layers = [
torch.nn.Linear(pred_n_hidden + enc_n_hidden, joint_n_hidden),
torch.nn.ReLU(),
] + ([torch.nn.Dropout(p=dropout), ] if dropout else []) + [
torch.nn.Linear(joint_n_hidden, vocab_size)
]
return torch.nn.Sequential(
*layers
)
    # Perhaps what I really need to do is provide a value for
    # state. But why can't I just specify a type for abstract
    # interpretation? That's what I really want!
    # We really want two "states" here...
    def forward(self, batch, state=None):
        # batch: ((x, y), (x_lens, y_lens))
        raise RuntimeError(
            "RNNT::forward is not currently used. "
            "It corresponds to training, where your entire output sequence "
            "is known beforehand.")
        # NOTE: the code below is unreachable; `self.encode` is not
        # defined on this class (the encoder module is `self.encoder`).
        # x: TxBxF
        (x, y_packed), (x_lens, y_lens) = batch
        x_packed = torch.nn.utils.rnn.pack_padded_sequence(x, x_lens)
        f, x_lens = self.encode(x_packed)
        g, _ = self.predict(y_packed, state)
        out = self.joint(f, g)
        return out, (x_lens, y_lens)
def predict(self, y, state=None, add_sos=True):
"""
B - batch size
U - label length
H - Hidden dimension size
L - Number of decoder layers = 2
Args:
y: (B, U)
Returns:
Tuple (g, hid) where:
g: (B, U + 1, H)
hid: (h, c) where h is the final sequence hidden state and c is
the final cell state:
h (tensor), shape (L, B, H)
c (tensor), shape (L, B, H)
"""
if isinstance(y, torch.Tensor):
y = self.prediction["embed"](y)
        elif isinstance(y, torch.nn.utils.rnn.PackedSequence):
            # Teacher-forced training mode
            # (B, U) -> (B, U, H)
            # _replace returns a new PackedSequence rather than mutating
            # in place, so the result must be reassigned.
            y = y._replace(data=self.prediction["embed"](y.data))
else:
# inference mode
B = 1 if state is None else state[0].size(1)
y = torch.zeros((B, 1, self.pred_n_hidden)).to(
device=self.joint_net[0].weight.device,
dtype=self.joint_net[0].weight.dtype
)
        # prepend blank "start of sequence" symbol
if add_sos:
B, U, H = y.shape
start = torch.zeros((B, 1, H)).to(device=y.device, dtype=y.dtype)
y = torch.cat([start, y], dim=1).contiguous() # (B, U + 1, H)
else:
start = None # makes del call later easier
y = y.transpose(0, 1) # .contiguous() # (U + 1, B, H)
g, hid = self.prediction["dec_rnn"](y, state)
g = g.transpose(0, 1) # .contiguous() # (B, U + 1, H)
del y, start, state
return g, hid
def joint(self, f, g):
"""
f should be shape (B, T, H)
g should be shape (B, U + 1, H)
returns:
logits of shape (B, T, U, K + 1)
"""
# Combine the input states and the output states
B, T, H = f.shape
B, U_, H2 = g.shape
f = f.unsqueeze(dim=2) # (B, T, 1, H)
f = f.expand((B, T, U_, H))
g = g.unsqueeze(dim=1) # (B, 1, U + 1, H)
g = g.expand((B, T, U_, H2))
inp = torch.cat([f, g], dim=3) # (B, T, U, 2H)
res = self.joint_net(inp)
del f, g, inp
return res
class Encoder(torch.nn.Module):
def __init__(self, in_features, encoder_n_hidden,
encoder_pre_rnn_layers, encoder_post_rnn_layers,
forget_gate_bias, norm, rnn_type, encoder_stack_time_factor,
dropout):
super().__init__()
self.pre_rnn = rnn(
rnn=rnn_type,
input_size=in_features,
hidden_size=encoder_n_hidden,
num_layers=encoder_pre_rnn_layers,
norm=norm,
forget_gate_bias=forget_gate_bias,
dropout=dropout,
)
self.stack_time = StackTime(factor=encoder_stack_time_factor)
self.post_rnn = rnn(
rnn=rnn_type,
input_size=encoder_stack_time_factor * encoder_n_hidden,
hidden_size=encoder_n_hidden,
num_layers=encoder_post_rnn_layers,
norm=norm,
forget_gate_bias=forget_gate_bias,
norm_first_rnn=True,
dropout=dropout,
)
def forward(self, x: torch.Tensor, x_lens: torch.Tensor):
x, _ = self.pre_rnn(x, None)
x, x_lens = self.stack_time(x, x_lens)
x, _ = self.post_rnn(x, None)
x = x.transpose(0, 1)
return x, x_lens
def label_collate(labels):
    """Collates the label inputs for the rnn-t prediction network.

    If `labels` is already a torch.Tensor it is simply cast to int64.

    Args:
        labels: A list of label-index sequences or a torch.Tensor.
    Returns:
        A padded torch.Tensor of shape (batch, max_seq_len).
    """
if isinstance(labels, torch.Tensor):
return labels.type(torch.int64)
if not isinstance(labels, (list, tuple)):
raise ValueError(
f"`labels` should be a list or tensor not {type(labels)}"
)
batch_size = len(labels)
max_len = max(len(l) for l in labels)
    cat_labels = np.full((batch_size, max_len), fill_value=0, dtype=np.int32)
for e, l in enumerate(labels):
cat_labels[e, :len(l)] = l
labels = torch.LongTensor(cat_labels)
return labels
```
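
For anyone trying to trace this model for TVM import, here is a minimal sketch of exercising the three sub-networks directly. It assumes `rnn.py` from the same repository is on the path, uses hypothetical hyperparameter values (the config keys match the ones `RNNT.__init__` reads above), and follows the shapes documented in `predict` and `joint`:

```python
import torch
from model_rnnt import RNNT

# Hypothetical config; keys match those read in RNNT.__init__.
rnnt_config = {
    "encoder_n_hidden": 1024,
    "encoder_pre_rnn_layers": 2,
    "encoder_post_rnn_layers": 3,
    "pred_n_hidden": 512,
    "pred_rnn_layers": 2,
    "joint_n_hidden": 512,
    "forget_gate_bias": 1.0,
    "rnn_type": "lstm",
    "encoder_stack_time_factor": 2,
    "dropout": 0.0,
}

# 29 output symbols (28 labels + blank) is illustrative. no_featurizer
# makes __init__ take in_features directly instead of feature_config.
model = RNNT(rnnt=rnnt_config, num_classes=29,
             no_featurizer=True, in_features=240)
model.eval()

T, B, F = 100, 1, 240
x = torch.randn(T, B, F)      # (time, batch, features)
x_lens = torch.tensor([T])

with torch.no_grad():
    # The encoder subsamples the time axis via StackTime(factor=2).
    f, f_lens = model.encoder(x, x_lens)      # f: (B, T // 2, 1024)
    # y=None with add_sos=False yields the "start of sequence" step:
    # g has shape (B, 1, 512).
    g, hidden = model.predict(None, add_sos=False)
    logits = model.joint(f, g)                # (B, T // 2, 1, 29)
```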