sxjscience closed pull request #10032: add axes support for dropouts in gluon URL: https://github.com/apache/incubator-mxnet/pull/10032
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py index d6402b769cb..b964c712ace 100644 --- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py +++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py @@ -180,16 +180,12 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N states = _get_begin_state(self, F, begin_state, inputs, batch_size) if self.drop_inputs: - first_input = inputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True) - self._initialize_input_masks(F, first_input, states) - inputs = F.broadcast_mul(inputs, self.drop_inputs_mask.expand_dims(axis=axis)) + inputs = F.Dropout(inputs, p=self.drop_inputs, axes=(axis,)) outputs, states = self.base_cell.unroll(length, inputs, states, layout, merge_outputs=True, valid_length=valid_length) if self.drop_outputs: - first_output = outputs.slice_axis(axis, 0, 1).split(1, axis=axis, squeeze_axis=True) - self._initialize_output_mask(F, first_output) - outputs = F.broadcast_mul(outputs, self.drop_outputs_mask.expand_dims(axis=axis)) + outputs = F.Dropout(outputs, p=self.drop_outputs, axes=(axis,)) merge_outputs = isinstance(outputs, tensor_types) if merge_outputs is None else \ merge_outputs outputs, _, _, _ = _format_sequence(length, outputs, layout, merge_outputs) diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index b61540dd61b..9dc1a240681 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -226,6 +226,8 @@ class Dropout(HybridBlock): ---------- rate : float Fraction of the input units to drop. Must be a number between 0 and 1. 
+ axes : tuple of int, default () + The axes on which dropout mask is shared. If empty, regular dropout is applied. Inputs: @@ -239,15 +241,16 @@ class Dropout(HybridBlock): `Dropout: A Simple Way to Prevent Neural Networks from Overfitting <http://www.cs.toronto.edu/~rsalakhu/papers/srivastava14a.pdf>`_ """ - def __init__(self, rate, **kwargs): + def __init__(self, rate, axes=(), **kwargs): super(Dropout, self).__init__(**kwargs) self._rate = rate + self._axes = axes def hybrid_forward(self, F, x): - return F.Dropout(x, p=self._rate, name='fwd') + return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd') def __repr__(self): - s = '{name}(p = {_rate})' + s = '{name}(p = {_rate}, axes={_axes})' return s.format(name=self.__class__.__name__, **self.__dict__) diff --git a/python/mxnet/gluon/rnn/rnn_cell.py b/python/mxnet/gluon/rnn/rnn_cell.py index 61bf24e8cd1..f5c72f5f3e7 100644 --- a/python/mxnet/gluon/rnn/rnn_cell.py +++ b/python/mxnet/gluon/rnn/rnn_cell.py @@ -713,6 +713,8 @@ class DropoutCell(HybridRecurrentCell): rate : float Percentage of elements to drop out, which is 1 - percentage to retain. + axes : tuple of int, default () + The axes on which dropout mask is shared. If empty, regular dropout is applied. Inputs: @@ -723,13 +725,14 @@ class DropoutCell(HybridRecurrentCell): - **out**: output tensor with shape `(batch_size, size)`. - **next_states**: returns input `states` directly. 
""" - def __init__(self, rate, prefix=None, params=None): + def __init__(self, rate, axes=(), prefix=None, params=None): super(DropoutCell, self).__init__(prefix, params) assert isinstance(rate, numeric_types), "rate must be a number" - self.rate = rate + self._rate = rate + self._axes = axes def __repr__(self): - s = '{name}(rate = {rate})' + s = '{name}(rate={_rate}, axes={_axes})' return s.format(name=self.__class__.__name__, **self.__dict__) @@ -740,8 +743,9 @@ def _alias(self): return 'dropout' def hybrid_forward(self, F, inputs, states): - if self.rate > 0: - inputs = F.Dropout(data=inputs, p=self.rate, name='t%d_fwd'%self._counter) + if self._rate > 0: + inputs = F.Dropout(data=inputs, p=self._rate, axes=self._axes, + name='t%d_fwd'%self._counter) return inputs, states def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=None, diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 89f52154370..889d210da34 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -827,6 +827,46 @@ def selu(x): x = point_to_validate.reshape((1, 3, 2)) assert_almost_equal(prelu(x).asnumpy(), mx.nd.where(x >= 0, x, 0.25 * x).asnumpy()) +@with_seed() +def test_dropout(): + def get_slice(x, axis, idx): + ix = () + for i in range(x.ndim): + if i == axis: + ix += (idx,) + else: + ix += (slice(None, None, None),) + return x[ix] + + def check_dropout_axes(ratio, shape, axes): + compactshape = list(shape) + for axis in axes: + compactshape[axis] = 1 + compactx = mx.random.uniform(shape=tuple(compactshape)) + broadcastx = compactx.broadcast_to(shape) + dropouty = mx.gluon.nn.Dropout(rate=ratio, axes=axes)(broadcastx) + for axis in axes: + target = get_slice(dropouty, axis, 0).asnumpy() + for i in range(1, shape[axis]): + assert(get_slice(dropouty, axis, i).asnumpy() == target).all() + + nshape = (10, 10, 10, 10) + with mx.autograd.train_mode(): + check_dropout_axes(0.25, nshape, axes = 
(0,)) + check_dropout_axes(0.25, nshape, axes = (1,)) + check_dropout_axes(0.25, nshape, axes = (2,)) + check_dropout_axes(0.25, nshape, axes = (3,)) + check_dropout_axes(0.25, nshape, axes = (0, 1)) + check_dropout_axes(0.25, nshape, axes = (0, 2)) + check_dropout_axes(0.25, nshape, axes = (0, 3)) + check_dropout_axes(0.25, nshape, axes = (1, 2)) + check_dropout_axes(0.25, nshape, axes = (1, 3)) + check_dropout_axes(0.25, nshape, axes = (2, 3)) + check_dropout_axes(0.25, nshape, axes = (0, 1, 2)) + check_dropout_axes(0.25, nshape, axes = (0, 2, 3)) + check_dropout_axes(0.25, nshape, axes = (1, 2, 3)) + + if __name__ == '__main__': diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py index 03e4261ad16..29850dce6ae 100644 --- a/tests/python/unittest/test_gluon_contrib.py +++ b/tests/python/unittest/test_gluon_contrib.py @@ -120,11 +120,8 @@ def check_vardrop(drop_inputs, drop_states, drop_outputs): input_data = mx.nd.random_uniform(shape=(10, 3, 50), ctx=mx.context.current_context()) with mx.autograd.record(): outputs1, _ = cell.unroll(3, input_data, merge_outputs=True) - mask1 = cell.drop_outputs_mask.asnumpy() mx.nd.waitall() outputs2, _ = cell.unroll(3, input_data, merge_outputs=True) - mask2 = cell.drop_outputs_mask.asnumpy() - assert not almost_equal(mask1, mask2) assert not almost_equal(outputs1.asnumpy(), outputs2.asnumpy()) inputs = [mx.sym.Variable('rnn_t%d_data'%i) for i in range(3)] diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 91b8faa49c1..2208a33e801 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -268,7 +268,7 @@ def check_regression(symbol, forward, backward, shape, stype='default', densitie lambda x: x, lambda x, y : x - y, shape, stype='csr') - + def check_softmax_grad(xpu): x = mx.sym.Variable('x') @@ -4674,19 +4674,20 @@ def check_dropout_axes(ratio, shape, axes): check_dropout_ratio(0.25, 
shape) nshape = (10, 10, 10, 10) - check_dropout_axes(0.25, nshape, axes = (0,)) - check_dropout_axes(0.25, nshape, axes = (1,)) - check_dropout_axes(0.25, nshape, axes = (2,)) - check_dropout_axes(0.25, nshape, axes = (3,)) - check_dropout_axes(0.25, nshape, axes = (0, 1)) - check_dropout_axes(0.25, nshape, axes = (0, 2)) - check_dropout_axes(0.25, nshape, axes = (0, 3)) - check_dropout_axes(0.25, nshape, axes = (1, 2)) - check_dropout_axes(0.25, nshape, axes = (1, 3)) - check_dropout_axes(0.25, nshape, axes = (2, 3)) - check_dropout_axes(0.25, nshape, axes = (0, 1, 2)) - check_dropout_axes(0.25, nshape, axes = (0, 2, 3)) - check_dropout_axes(0.25, nshape, axes = (1, 2, 3)) + with mx.autograd.train_mode(): + check_dropout_axes(0.25, nshape, axes = (0,)) + check_dropout_axes(0.25, nshape, axes = (1,)) + check_dropout_axes(0.25, nshape, axes = (2,)) + check_dropout_axes(0.25, nshape, axes = (3,)) + check_dropout_axes(0.25, nshape, axes = (0, 1)) + check_dropout_axes(0.25, nshape, axes = (0, 2)) + check_dropout_axes(0.25, nshape, axes = (0, 3)) + check_dropout_axes(0.25, nshape, axes = (1, 2)) + check_dropout_axes(0.25, nshape, axes = (1, 3)) + check_dropout_axes(0.25, nshape, axes = (2, 3)) + check_dropout_axes(0.25, nshape, axes = (0, 1, 2)) + check_dropout_axes(0.25, nshape, axes = (0, 2, 3)) + check_dropout_axes(0.25, nshape, axes = (1, 2, 3)) @with_seed() ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services