feevos commented on issue #19609:
URL: 
https://github.com/apache/incubator-mxnet/issues/19609#issuecomment-774105316


   Dear all, this is the final solution that works for me. It uses the legacy operator `mx.nd.GroupNorm`:
   
   ```Python
   import mxnet as mx
   from mxnet.gluon import HybridBlock
   from mxnet.gluon.parameter import Parameter
   
   @mx.use_np
   class GroupNorm(HybridBlock):
       r"""
       Applies group normalization to the n-dimensional input array.
       This operator takes an n-dimensional input array where the leftmost 2
       axes are `batch` and `channel` respectively:
       .. math::
         x = x.reshape((N, num_groups, C // num_groups, ...))
         axis = (2, ...)
         out = \frac{x - mean[x, axis]}{\sqrt{Var[x, axis] + \epsilon}} * gamma + beta
       Parameters
       ----------
       num_groups: int, default 1
           Number of groups to separate the channel axis into.
       epsilon: float, default 1e-5
           Small float added to variance to avoid dividing by zero.
       center: bool, default True
           If True, add offset of `beta` to normalized tensor.
           If False, `beta` is ignored.
       scale: bool, default True
           If True, multiply by `gamma`. If False, `gamma` is not used.
       beta_initializer: str or `Initializer`, default 'zeros'
           Initializer for the beta weight.
       gamma_initializer: str or `Initializer`, default 'ones'
           Initializer for the gamma weight.
       Inputs:
           - **data**: input tensor with shape (N, C, ...).
       Outputs:
           - **out**: output tensor with the same shape as `data`.
       References
       ----------
           `Group Normalization
           <https://arxiv.org/pdf/1803.08494.pdf>`_
       Examples
       --------
       # Input of shape (2, 3, 4)
       x = mx.nd.array([[[ 0,  1,  2,  3],
                         [ 4,  5,  6,  7],
                         [ 8,  9, 10, 11]],
                        [[12, 13, 14, 15],
                         [16, 17, 18, 19],
                         [20, 21, 22, 23]]])
       # Group normalization is calculated with the above formula
       layer = GroupNorm()
       layer.initialize(ctx=mx.cpu(0))
       layer(x)
       [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
         [-0.4345239 -0.1448413  0.1448413  0.4345239]
         [ 0.7242065  1.0138891  1.3035717  1.5932543]]
        [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
         [-0.4345239 -0.1448413  0.1448413  0.4345239]
         [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
       <NDArray 2x3x4 @cpu(0)>
       """
       def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
                    beta_initializer='zeros', gamma_initializer='ones',
                    in_channels=0):
           super(GroupNorm, self).__init__()
           self._kwargs = {'eps': epsilon, 'num_groups': num_groups,
                           'center': center, 'scale': scale}
           self._num_groups = num_groups
           self._epsilon = epsilon
           self._center = center
           self._scale = scale
           self.gamma = Parameter('gamma', grad_req='write' if scale else 'null',
                                  shape=(in_channels,), init=gamma_initializer,
                                  allow_deferred_init=True)
           self.beta = Parameter('beta', grad_req='write' if center else 'null',
                                 shape=(in_channels,), init=beta_initializer,
                                 allow_deferred_init=True)
   
       def infer_shape(self, in_shape):
           # Necessary for MXNet 2.0: resolve the deferred parameter shapes
           # from the channel axis (axis 1) of the input.
           tshape = in_shape.shape
           self.gamma.shape = (tshape[1],)
           self.beta.shape = (tshape[1],)
   
       def forward(self, x):
           # The legacy operator works on classic NDArrays, so convert the
           # inputs before the call and convert the result back afterwards.
           gamma = self.gamma.data().as_nd_ndarray()
           beta = self.beta.data().as_nd_ndarray()
           x = mx.nd.GroupNorm(data=x.as_nd_ndarray(), gamma=gamma, beta=beta,
                               num_groups=self._num_groups, eps=self._epsilon)
           return x.as_np_ndarray()
   
   
       def __repr__(self):
           s = '{name}({content}'
           in_channels = self.gamma.shape[0]
           s += ', in_channels={0}'.format(in_channels)
           s += ')'
           return s.format(name=self.__class__.__name__,
                           content=', '.join(['='.join([k, v.__repr__()])
                                              for k, v in self._kwargs.items()]))
   
   
   ```
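
For anyone who wants to sanity-check the layer without MXNet installed, here is a minimal NumPy sketch of the formula from the docstring above. The name `group_norm_reference` is my own and is not part of MXNet; the sketch assumes a biased variance (that is the choice that reproduces the numbers in the docstring example) and per-channel `gamma`/`beta`.

```python
import numpy as np


def group_norm_reference(x, num_groups=1, eps=1e-5, gamma=None, beta=None):
    """Hypothetical NumPy reference for the docstring formula (not an MXNet API)."""
    n, c = x.shape[:2]
    if c % num_groups != 0:
        raise ValueError("channel count must be divisible by num_groups")
    spatial = x.shape[2:]
    # Split the channel axis into (num_groups, C // num_groups).
    grouped = x.reshape((n, num_groups, c // num_groups) + spatial)
    # Normalize over every axis except batch and group.
    axes = tuple(range(2, grouped.ndim))
    mean = grouped.mean(axis=axes, keepdims=True)
    var = grouped.var(axis=axes, keepdims=True)  # biased variance (assumption)
    out = ((grouped - mean) / np.sqrt(var + eps)).reshape(x.shape)
    # Optional per-channel scale and shift, broadcast over the trailing axes.
    param_shape = (1, c) + (1,) * len(spatial)
    if gamma is not None:
        out = out * np.asarray(gamma).reshape(param_shape)
    if beta is not None:
        out = out + np.asarray(beta).reshape(param_shape)
    return out


# Same input as the docstring example: shape (2, 3, 4), values 0..23.
x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
out = group_norm_reference(x, num_groups=1)
print(out[0, 0])
```

With `num_groups=1` the first row should agree with the docstring output (`-1.5932543 -1.3035717 ...`) up to floating-point rounding, which gives a quick way to compare against whatever the `GroupNorm` block returns.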

