This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new c02c6b1 fill parameter shape (#8528) c02c6b1 is described below commit c02c6b149b40bc9a8db91c95453ff0e96f3edc3c Author: Sheng Zha <s...@users.noreply.github.com> AuthorDate: Fri Nov 3 10:34:57 2017 -0700 fill parameter shape (#8528) --- python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py | 5 ++-- python/mxnet/gluon/nn/basic_layers.py | 8 +++--- python/mxnet/gluon/nn/conv_layers.py | 5 ++-- python/mxnet/gluon/parameter.py | 7 +++-- python/mxnet/gluon/rnn/rnn_cell.py | 38 ++++++++++++++++++------- python/mxnet/gluon/rnn/rnn_layer.py | 4 +-- tests/python/unittest/test_gluon.py | 36 +++++++++++++++++++++++ tests/python/unittest/test_gluon_contrib.py | 7 +++++ tests/python/unittest/test_gluon_rnn.py | 13 +++++++++ 9 files changed, 99 insertions(+), 24 deletions(-) diff --git a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py index cbb3f1a..09db547 100644 --- a/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py +++ b/python/mxnet/gluon/contrib/rnn/conv_rnn_cell.py @@ -131,8 +131,9 @@ class _BaseConvRNNCell(HybridRecurrentCell): s += ', {_conv_layout}' s += ')' attrs = self.__dict__ - mapping = ('{_in_channels} -> {_hidden_channels}'.format(**attrs) if self._in_channels - else self._hidden_channels) + shape = self.i2h_weight.shape + in_channels = shape[1 if self._channel_axis == 1 else -1] + mapping = ('{0} -> {1}'.format(in_channels if in_channels else None, shape[0])) return s.format(name=self.__class__.__name__, mapping=mapping, **attrs) diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index e9fb2ff..906f03e 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -207,10 +207,10 @@ class Dense(HybridBlock): def __repr__(self): s = '{name}({layout}, {act})' + shape = self.weight.shape return s.format(name=self.__class__.__name__, act=self.act if self.act else 'linear', - 
layout='{0} -> {1}'.format(self._in_units, self._units) if self._in_units - else self._units) + layout='{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0])) class Activation(HybridBlock): @@ -360,8 +360,8 @@ class BatchNorm(HybridBlock): def __repr__(self): s = '{name}({content}' - if hasattr(self, 'in_channels'): - s += ', in_channels={0}'.format(self.in_channels) + in_channels = self.gamma.shape[0] + s += ', in_channels={0}'.format(in_channels if in_channels else None) s += ')' return s.format(name=self.__class__.__name__, content=', '.join(['='.join([k, v.__repr__()]) diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 8dcdbc3..645de98 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -153,10 +153,9 @@ class _Conv(HybridBlock): if self.bias is None: s += ', bias=False' s += ')' + shape = self.weight.shape return s.format(name=self.__class__.__name__, - mapping=self._channels if not self._in_channels - else '{0} -> {1}'.format(self._in_channels, - self._channels), + mapping='{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]), **self._kwargs) diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index c73aee2..c42fbaa 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -171,11 +171,12 @@ class Parameter(object): def _load_init(self, data, ctx): """(Re)initializes by loading from data.""" if self.shape: - for i, j in zip(self.shape, data.shape): - assert i == 0 or i == j, \ + for self_dim, data_dim in zip(self.shape, data.shape): + assert self_dim == 0 or self_dim == data_dim, \ "Failed loading Parameter %s from saved params: " \ "shape incompatible expacted %s vs saved %s"%( self.name, str(self.shape), str(data.shape)) + self.shape = tuple(i if i != 0 else j for i, j in zip(self.shape, data.shape)) if self.dtype: assert np.dtype(self.dtype).type == data.dtype, \ "Failed loading Parameter %s from 
saved params: " \ @@ -344,6 +345,8 @@ class Parameter(object): "Parameter %s has not been initialized"%self.name for arr in self.list_data(): arr[:] = data + if not self.shape or np.prod(self.shape) <= 0: + self.shape = data.shape def data(self, ctx=None): """Returns a copy of this parameter on one context. Must have been diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py index 9d318eb..ea0e32f 100644 --- a/python/mxnet/gluon/rnn/rnn_cell.py +++ b/python/mxnet/gluon/rnn/rnn_cell.py @@ -111,17 +111,6 @@ class RecurrentCell(Block): self._modified = False self.reset() - def __repr__(self): - s = '{name}({mapping}' - if hasattr(self, '_activation'): - s += ', {_activation}' - s += ')' - mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size - else self._hidden_size) - return s.format(name=self.__class__.__name__, - mapping=mapping, - **self.__dict__) - def reset(self): """Reset before re-using the cell for another graph.""" self._init_counter = -1 @@ -355,6 +344,17 @@ class RNNCell(HybridRecurrentCell): def _alias(self): return 'rnn' + def __repr__(self): + s = '{name}({mapping}' + if hasattr(self, '_activation'): + s += ', {_activation}' + s += ')' + shape = self.i2h_weight.shape + mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]) + return s.format(name=self.__class__.__name__, + mapping=mapping, + **self.__dict__) + def hybrid_forward(self, F, inputs, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): prefix = 't%d_'%self._counter @@ -453,6 +453,14 @@ class LSTMCell(HybridRecurrentCell): def _alias(self): return 'lstm' + def __repr__(self): + s = '{name}({mapping})' + shape = self.i2h_weight.shape + mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]) + return s.format(name=self.__class__.__name__, + mapping=mapping, + **self.__dict__) + def hybrid_forward(self, F, inputs, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): prefix = 't%d_'%self._counter @@ 
-551,6 +559,14 @@ class GRUCell(HybridRecurrentCell): def _alias(self): return 'gru' + def __repr__(self): + s = '{name}({mapping})' + shape = self.i2h_weight.shape + mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]) + return s.format(name=self.__class__.__name__, + mapping=mapping, + **self.__dict__) + def hybrid_forward(self, F, inputs, states, i2h_weight, h2h_weight, i2h_bias, h2h_bias): # pylint: disable=too-many-locals diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py index 2d7c008..3a4f712 100644 --- a/python/mxnet/gluon/rnn/rnn_layer.py +++ b/python/mxnet/gluon/rnn/rnn_layer.py @@ -89,8 +89,8 @@ class _RNNLayer(Block): if self._dir == 2: s += ', bidirectional' s += ')' - mapping = ('{_input_size} -> {_hidden_size}'.format(**self.__dict__) if self._input_size - else self._hidden_size) + shape = self.i2h_weight[0].shape + mapping = '{0} -> {1}'.format(shape[1] if shape[1] else None, shape[0]) return s.format(name=self.__class__.__name__, mapping=mapping, **self.__dict__) diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 6f9966b..df0af34 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -553,8 +553,44 @@ def test_lambda(): assert_almost_equal(out1.asnumpy(), out3.asnumpy()) +def test_fill_shape_deferred(): + net = nn.HybridSequential() + with net.name_scope(): + net.add(nn.Conv2D(64, kernel_size=2, padding=1), + nn.BatchNorm(), + nn.Dense(10)) + net.hybridize() + net.initialize() + net(mx.nd.ones((2,3,5,7))) + assert net[0].weight.shape[1] == 3, net[0].weight.shape[1] + assert net[1].gamma.shape[0] == 64, net[1].gamma.shape[0] + assert net[2].weight.shape[1] == 3072, net[2].weight.shape[1] +def test_fill_shape_load(): + ctx = mx.context.current_context() + net1 = nn.HybridSequential() + with net1.name_scope(): + net1.add(nn.Conv2D(64, kernel_size=2, padding=1), + nn.BatchNorm(), + nn.Dense(10)) + net1.hybridize() + 
net1.initialize(ctx=ctx) + net1(mx.nd.ones((2,3,5,7), ctx)) + net1.save_params('net_fill.params') + + net2 = nn.HybridSequential() + with net2.name_scope(): + net2.add(nn.Conv2D(64, kernel_size=2, padding=1), + nn.BatchNorm(), + nn.Dense(10)) + net2.hybridize() + net2.initialize() + net2.load_params('net_fill.params', ctx) + assert net2[0].weight.shape[1] == 3, net2[0].weight.shape[1] + assert net2[1].gamma.shape[0] == 64, net2[1].gamma.shape[0] + assert net2[2].weight.shape[1] == 3072, net2[2].weight.shape[1] + if __name__ == '__main__': import nose diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index c99836c..07b8956 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -94,6 +94,13 @@ def test_convgru(): check_rnn_cell(cell, prefix='rnn_', in_shape=(1, 10, 20, 30, 50), out_shape=(1, 100, 18, 28, 48)) +def test_conv_fill_shape(): + cell = contrib.rnn.Conv1DLSTMCell((0, 7), 10, (3,), (3,)) + cell.hybridize() + check_rnn_forward(cell, mx.nd.ones((8, 3, 5, 7))) + assert cell.i2h_weight.shape[1] == 5, cell.i2h_weight.shape[1] + + def test_vardrop(): def check_vardrop(drop_inputs, drop_states, drop_outputs): cell = contrib.rnn.VariationalDropoutCell(mx.gluon.rnn.RNNCell(100, prefix='rnn_'), diff --git a/tests/python/unittest/test_gluon_rnn.py b/tests/python/unittest/test_gluon_rnn.py index f71ac18..2288842 100644 --- a/tests/python/unittest/test_gluon_rnn.py +++ b/tests/python/unittest/test_gluon_rnn.py @@ -274,6 +274,19 @@ def test_rnn_layers(): with mx.autograd.record(): net(mx.nd.ones((2, 3, 10))).backward() +def test_cell_fill_shape(): + cell = gluon.rnn.LSTMCell(10) + cell.hybridize() + check_rnn_forward(cell, mx.nd.ones((2, 3, 7))) + assert cell.i2h_weight.shape[1] == 7, cell.i2h_weight.shape[1] + +def test_layer_fill_shape(): + layer = gluon.rnn.LSTM(10) + layer.hybridize() + check_rnn_layer_forward(layer, mx.nd.ones((3, 2, 7))) + print(layer) + 
assert layer.i2h_weight[0].shape[1] == 7, layer.i2h_weight[0].shape[1] + if __name__ == '__main__': import nose -- To stop receiving notification emails like this one, please contact comm...@mxnet.apache.org.