I'm using recurrent neural networks from `mxnet.gluon.rnn` to build a simple
language model:
```python
import mxnet as mx


class MyRNN(mx.gluon.HybridBlock):
    def __init__(self, vocab_dim, emb_dim, hidden_size, n_layers, **kwargs):
        super().__init__(**kwargs)
        with self.name_scope():
            self.embedding = mx.gluon.nn.Embedding(vocab_dim, emb_dim)
            self.rnn = mx.gluon.rnn.LSTM(hidden_size=hidden_size,
                                         num_layers=n_layers)
            self.output = mx.gluon.nn.Dense(vocab_dim, flatten=False)

    def hybrid_forward(self, f, x, states=None):
        if states is None:
            if f == mx.nd:
                # imperative mode: the batch size is known from the input
                states = self.rnn.begin_state(batch_size=x.shape[0],
                                              func=mx.nd.zeros)
            else:
                # symbolic mode: no shape available, batch_size defaults to 0
                states = self.rnn.begin_state(func=mx.sym.zeros)
        x = f.transpose(x)  # (batch, seq) -> (seq, batch), the LSTM's layout
        x = self.embedding(x)
        x, states = self.rnn(x, states)
        # project to the vocabulary and restore the (batch, seq, vocab) layout
        return self.output(x).swapaxes(0, 1), states
```
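For reference, the imperative (non-hybridized) path goes through, since the `f == mx.nd` branch can read the batch size from `x.shape`. A minimal sanity check (the `float32` cast is my own precaution for `Embedding`; the expected output shape follows from the sizes above):

```python
# sketch: imperative mode only, no hybridize()
x = mx.nd.random.randint(0, 10, shape=(3, 5)).astype('float32')
rnn = MyRNN(10, 4, 8, 2)
rnn.initialize()
out, states = rnn(x)
print(out.shape)  # (3, 5, 10) = (batch, seq, vocab) after the final swapaxes
```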
For the `Symbol` API, `begin_state` does not require `batch_size` (it defaults to `0`), and in fact we could not do better here: in symbolic mode the shape of the input `x` cannot be inspected, and an integer `batch_size` cannot be a formal parameter of `hybrid_forward`. Still, once the block is hybridized, forward propagation initializes a state list with a zero-sized batch dimension, and the subsequent operations fail:
```python
>>> x = mx.nd.random.randint(0, 10, shape=(3, 5))
>>> rnn = MyRNN(10, 4, 8, 2)
>>> rnn.hybridize()
>>> rnn.initialize()
>>> rnn(x)
Traceback (most recent call last):
...
mxnet.gluon.parameter.DeferredInitializationError: Parameter
'myrnn0_lstm0_l0_i2h_weight' has not been initialized yet because
initialization was deferred. Actual initialization happens during the first
forward pass. Please pass one batch of data through the network before
accessing Parameters. You can also avoid deferred initialization by specifying
in_units, num_features, etc., for network layers.
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
...
mxnet.base.MXNetError: MXNetError: Error in operator myrnn0_lstm0_rnn0: Shape
inconsistent, Provided = [2,0,8], inferred shape=(2,3,8)
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
...
ValueError: Deferred initialization failed because shape cannot be inferred.
MXNetError: Error in operator myrnn0_lstm0_rnn0: Shape inconsistent, Provided =
[2,0,8], inferred shape=(2,3,8)
```
I guess the only way to avoid this is to initialize the states outside of the `hybrid_forward` scope and pass them to the block explicitly. In any case, the behaviour looks like a bug in `mxnet.gluon.rnn._RNNLayer`.
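A minimal sketch of that workaround, assuming a hybridized block accepts the state NDArrays as a second positional argument (which the `states=None` signature above should permit):

```python
# sketch: create the initial states imperatively, with a concrete batch
# size, so hybrid_forward never has to call begin_state symbolically
x = mx.nd.random.randint(0, 10, shape=(3, 5)).astype('float32')
rnn = MyRNN(10, 4, 8, 2)
rnn.initialize()
rnn.hybridize()
states = rnn.rnn.begin_state(batch_size=x.shape[0], func=mx.nd.zeros)
out, states = rnn(x, states)
```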