feevos commented on issue #19609:
URL: https://github.com/apache/incubator-mxnet/issues/19609#issuecomment-774105316
Dear all, this is the final solution that works for me, using the legacy operator `mx.nd.GroupNorm`:
```Python
import mxnet as mx
from mxnet.gluon import HybridBlock
from mxnet.gluon.parameter import Parameter


@mx.use_np
class GroupNorm(HybridBlock):
    r"""Applies group normalization to the n-dimensional input array.

    This operator takes an n-dimensional input array in which the two
    leftmost axes are `batch` and `channel`, respectively:

    .. math::

        x = x.reshape((N, num_groups, C // num_groups, ...))
        axis = (2, ...)
        out = \frac{x - mean[x, axis]}{\sqrt{Var[x, axis] + \epsilon}} * gamma + beta

    Parameters
    ----------
    num_groups: int, default 1
        Number of groups to separate the channel axis into.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.

    Inputs:
        - **data**: input tensor with shape (N, C, ...).

    Outputs:
        - **out**: output tensor with the same shape as `data`.

    References
    ----------
    `Group Normalization <https://arxiv.org/pdf/1803.08494.pdf>`_

    Examples
    --------
    # Input of shape (2, 3, 4)
    x = mx.nd.array([[[ 0,  1,  2,  3],
                      [ 4,  5,  6,  7],
                      [ 8,  9, 10, 11]],
                     [[12, 13, 14, 15],
                      [16, 17, 18, 19],
                      [20, 21, 22, 23]]])
    # Group normalization is calculated with the above formula
    layer = GroupNorm()
    layer.initialize(ctx=mx.cpu(0))
    layer(x)
    [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413  0.1448413  0.4345239]
      [ 0.7242065  1.0138891  1.3035717  1.5932543]]
     [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
      [-0.4345239 -0.1448413  0.1448413  0.4345239]
      [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
    <NDArray 2x3x4 @cpu(0)>
    """

    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                 beta_initializer='zeros', gamma_initializer='ones',
                 in_channels=0):
        super(GroupNorm, self).__init__()
        self._kwargs = {'eps': epsilon, 'num_groups': num_groups,
                        'center': center, 'scale': scale}
        self._num_groups = num_groups
        self._epsilon = epsilon
        self._center = center
        self._scale = scale
        self.gamma = Parameter('gamma',
                               grad_req='write' if scale else 'null',
                               shape=(in_channels,), init=gamma_initializer,
                               allow_deferred_init=True)
        self.beta = Parameter('beta',
                              grad_req='write' if center else 'null',
                              shape=(in_channels,), init=beta_initializer,
                              allow_deferred_init=True)

    def infer_shape(self, in_shape):
        # Necessary for MXNet 2.0: resolve the deferred parameter shapes
        # from the channel axis of the first input.
        tshape = in_shape.shape
        self.gamma.shape = (tshape[1],)
        self.beta.shape = (tshape[1],)

    def forward(self, x):
        # mx.np arrays cannot be fed to the legacy mx.nd operator directly,
        # so bridge to the legacy NDArray interface and back.
        gamma = self.gamma.data().as_nd_ndarray()
        beta = self.beta.data().as_nd_ndarray()
        x = mx.nd.GroupNorm(data=x.as_nd_ndarray(), gamma=gamma, beta=beta,
                            num_groups=self._num_groups, eps=self._epsilon)
        return x.as_np_ndarray()

    def __repr__(self):
        s = '{name}({content}'
        in_channels = self.gamma.shape[0]
        s += ', in_channels={0}'.format(in_channels)
        s += ')'
        return s.format(name=self.__class__.__name__,
                        content=', '.join(['='.join([k, v.__repr__()])
                                           for k, v in self._kwargs.items()]))
```
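For reference, a minimal smoke test of the block above (a sketch assuming MXNet 2.x with the `mx.np` frontend; the input values and the `num_groups` choice are illustrative only):

```Python
import mxnet as mx
from mxnet import np

# Illustrative (N, C, ...) input; any float array works.
x = np.arange(24, dtype='float32').reshape(2, 3, 4)

layer = GroupNorm(num_groups=1)
layer.initialize(ctx=mx.cpu(0))
out = layer(x)    # first call triggers infer_shape(), sizing gamma/beta to C=3
print(out.shape)  # (2, 3, 4) -- same shape as the input
print(type(out))  # mxnet.numpy.ndarray, restored by as_np_ndarray()
```

The `infer_shape` override is what lets the deferred `gamma`/`beta` parameters pick up the channel dimension on the first call, so `in_channels` does not have to be passed up front.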