I'm using recurrent neural networks from `mxnet.gluon.rnn` to build a simple 
language model:

```python
import mxnet as mx

class MyRNN(mx.gluon.HybridBlock):
    def __init__(self, vocab_dim, emb_dim, hidden_size, n_layers, **kwargs):
        super().__init__(**kwargs)
        with self.name_scope():
            self.embedding = mx.gluon.nn.Embedding(vocab_dim, emb_dim)
            self.rnn = mx.gluon.rnn.LSTM(hidden_size=hidden_size,
                                         num_layers=n_layers)
            self.output = mx.gluon.nn.Dense(vocab_dim, flatten=False)

    def hybrid_forward(self, F, x, states=None):
        if states is None:
            if F is mx.nd:
                # Imperative mode: the batch size is known from the input.
                states = self.rnn.begin_state(batch_size=x.shape[0],
                                              func=mx.nd.zeros)
            else:
                # Symbolic mode: shapes are unknown here, so batch_size
                # keeps its default of 0.
                states = self.rnn.begin_state(func=mx.sym.zeros)
        x = F.transpose(x)  # (batch, seq) -> (seq, batch), the RNN layout
        x = self.embedding(x)
        x, states = self.rnn(x, states)
        # Swap back to batch-major: (seq, batch, vocab) -> (batch, seq, vocab)
        return self.output(x).swapaxes(0, 1), states
```
With the `Symbol` API, `begin_state` does not require `batch_size` (it defaults
to 0), and in fact it cannot be anything else: we cannot infer the shape of the
input `x` from a symbol, nor accept an integer `batch_size` as a formal
parameter of `hybrid_forward`. But once the block is hybridized, the forward
pass creates a list of states whose batch dimension is exactly 0, and the
subsequent operations fail:
```python
>>> x = mx.nd.random.randint(0, 10, shape=(3, 5))
>>> rnn = MyRNN(10, 4, 8, 2)
>>> rnn.hybridize()
>>> rnn.initialize()
>>> rnn(x)
Traceback (most recent call last):
  ...
mxnet.gluon.parameter.DeferredInitializationError: Parameter 
'myrnn0_lstm0_l0_i2h_weight' has not been initialized yet because 
initialization was deferred. Actual initialization happens during the first 
forward pass. Please pass one batch of data through the network before 
accessing Parameters. You can also avoid deferred initialization by specifying 
in_units, num_features, etc., for network layers.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  ...
mxnet.base.MXNetError: MXNetError: Error in operator myrnn0_lstm0_rnn0: Shape 
inconsistent, Provided = [2,0,8], inferred shape=(2,3,8)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  ...
ValueError: Deferred initialization failed because shape cannot be inferred. 
MXNetError: Error in operator myrnn0_lstm0_rnn0: Shape inconsistent, Provided = 
[2,0,8], inferred shape=(2,3,8)
```

I guess the only way to avoid this is to initialize the states outside of the
`hybrid_forward` scope. In any case, the error looks like a bug in
`mxnet.gluon.rnn._RNNLayer`.
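
For what it's worth, here is a minimal sketch of that workaround: make `states`
a required input and create it with `begin_state` before calling the network,
so the hybridized graph sees the states as ordinary inputs with known shapes.
The `begin_state` wrapper method and the `astype('float32')` cast are my own
additions, not part of the original model:

```python
import mxnet as mx

class MyRNN2(mx.gluon.HybridBlock):
    """Same model, but the caller must always supply `states`."""

    def __init__(self, vocab_dim, emb_dim, hidden_size, n_layers, **kwargs):
        super().__init__(**kwargs)
        with self.name_scope():
            self.embedding = mx.gluon.nn.Embedding(vocab_dim, emb_dim)
            self.rnn = mx.gluon.rnn.LSTM(hidden_size=hidden_size,
                                         num_layers=n_layers)
            self.output = mx.gluon.nn.Dense(vocab_dim, flatten=False)

    def begin_state(self, batch_size, **kwargs):
        # Thin convenience wrapper so callers don't reach into self.rnn.
        return self.rnn.begin_state(batch_size=batch_size, **kwargs)

    def hybrid_forward(self, F, x, states):
        # `states` is now a formal input of the graph, so shape inference
        # sees concrete state shapes instead of a zero batch dimension.
        x = self.embedding(F.transpose(x))
        x, states = self.rnn(x, states)
        return self.output(x).swapaxes(0, 1), states

x = mx.nd.random.randint(0, 10, shape=(3, 5)).astype('float32')
rnn = MyRNN2(10, 4, 8, 2)
rnn.hybridize()
rnn.initialize()
states = rnn.begin_state(batch_size=x.shape[0])  # created outside the graph
out, states = rnn(x, states)
print(out.shape)  # (3, 5, 10)
```

This hybridizes cleanly because `begin_state` runs imperatively on NDArrays,
but it pushes the bookkeeping onto every caller, which is exactly what the
`states=None` default was meant to avoid.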
