timprepscius commented on issue #20659:
URL: 
https://github.com/apache/incubator-mxnet/issues/20659#issuecomment-947035378


   I'm only doing one test for now; I have to work on other things. Let me know if
   this works for you. It doesn't work for me on either 1.x or master.
   
   I just ran it again using your `mx.contrib.onnx.import_to_gluon('onnx_test', ctx=mx.cpu())`,
   which I didn't realize existed.
   
   
   
   Here it is; use `complex=False`.
   
   Instantiate it like so:
   
   ```
   if test == 4:
       # Real-valued model (complex=False) used for this report; `test` is
       # set by the reporter's harness, which is not shown here.
       m = DCUnet20(complex=False)  
   ```
   
   The input looks like this:
   
   ```
   # All-ones input of shape (1, 1, 1537, 215); with complex=False the first
   # encoder is an nn.Conv2d, so this is read as (batch, channels, height, width).
   # NOTE: `input` shadows the Python built-in of the same name.
   input = torch.ones((1, 1, 1537, 215))
   # Forward pass, detached from the autograd graph so it can become a NumPy array.
   output = m(input).detach().numpy()
   
   ```
   
   If you do end up debugging, I find this useful:
   
   ```
   # Print arrays in full (no "..." summarisation) and show each float with the
   # fewest digits that still identify it uniquely. Assumes `import sys` and
   # `import numpy as np` have been done by the caller.
   np.set_printoptions(threshold=sys.maxsize, precision=None, 
floatmode='unique')
   ```
   
   
   ```
   
   import torch
   import torch.nn as nn
   
   def log_(*args):
       """Debug-trace hook; currently ignores its arguments and does nothing.

       To see the tensor shapes reported from DCUnet20.forward, replace the body
       with ``print(*args)``.
       """
       return None
   
   class CConv2d(nn.Module):
       """
       Complex-valued 2-D convolution assembled from two real ``nn.Conv2d`` layers.

       Tensors carry the real part at ``[..., 0]`` and the imaginary part at
       ``[..., 1]``. The forward pass computes
           real = real_conv(re) - im_conv(im)
           imag = im_conv(re) + real_conv(im)
       and stacks the two results on a new trailing axis.
       """
       def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
           super().__init__()

           self.in_channels = in_channels
           self.out_channels = out_channels
           self.kernel_size = kernel_size
           self.padding = padding
           self.stride = stride

           conv_args = dict(in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            padding=padding,
                            stride=stride)
           # Creation order (real, then imaginary) followed by the re-init below
           # fixes which RNG draws each layer gets under a given seed.
           self.real_conv = nn.Conv2d(**conv_args)
           self.im_conv = nn.Conv2d(**conv_args)

           # Glorot (Xavier) uniform weights; biases keep PyTorch's default init.
           for conv in (self.real_conv, self.im_conv):
               nn.init.xavier_uniform_(conv.weight)

       def forward(self, x):
           re, im = x[..., 0], x[..., 1]
           out_re = self.real_conv(re) - self.im_conv(im)
           out_im = self.im_conv(re) + self.real_conv(im)
           return torch.stack((out_re, out_im), dim=-1)
   
   
   # + colab={} colab_type="code" id="GgtxJbSQ5i96"
   class CConvTranspose2d(nn.Module):
       """
       Complex-valued 2-D transposed convolution assembled from two real
       ``nn.ConvTranspose2d`` layers.

       Tensors carry the real part at ``[..., 0]`` and the imaginary part at
       ``[..., 1]``. The forward pass computes
           real = real_convt(re) - im_convt(im)
           imag = im_convt(re) + real_convt(im)
       and stacks the two results on a new trailing axis.
       """
       def __init__(self, in_channels, out_channels, kernel_size, stride, output_padding=0, padding=0):
           super().__init__()

           self.in_channels = in_channels
           self.out_channels = out_channels
           self.kernel_size = kernel_size
           self.output_padding = output_padding
           self.padding = padding
           self.stride = stride

           layer_args = dict(in_channels=in_channels,
                             out_channels=out_channels,
                             kernel_size=kernel_size,
                             output_padding=output_padding,
                             padding=padding,
                             stride=stride)
           # Creation order (real, then imaginary) followed by the re-init below
           # fixes which RNG draws each layer gets under a given seed.
           self.real_convt = nn.ConvTranspose2d(**layer_args)
           self.im_convt = nn.ConvTranspose2d(**layer_args)

           # Glorot (Xavier) uniform weights; biases keep PyTorch's default init.
           for layer in (self.real_convt, self.im_convt):
               nn.init.xavier_uniform_(layer.weight)

       def forward(self, x):
           re, im = x[..., 0], x[..., 1]
           out_re = self.real_convt(re) - self.im_convt(im)
           out_im = self.im_convt(re) + self.real_convt(im)
           return torch.stack((out_re, out_im), dim=-1)
   
   
   # + colab={} colab_type="code" id="OJSmVrxp5i9-"
   class CBatchNorm2d(nn.Module):
       """
       "Complex" batch normalisation built from two independent ``nn.BatchNorm2d``.

       The real part (``[..., 0]``) and imaginary part (``[..., 1]``) are each
       normalised with their own statistics and affine parameters; no
       real/imaginary covariance whitening is performed. The results are stacked
       back on a new trailing axis.
       """
       def __init__(self, num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True):
           super().__init__()

           self.num_features = num_features
           self.eps = eps
           self.momentum = momentum
           self.affine = affine
           self.track_running_stats = track_running_stats

           bn_args = dict(num_features=num_features,
                          eps=eps,
                          momentum=momentum,
                          affine=affine,
                          track_running_stats=track_running_stats)
           self.real_b = nn.BatchNorm2d(**bn_args)
           self.im_b = nn.BatchNorm2d(**bn_args)

       def forward(self, x):
           parts = (self.real_b(x[..., 0]), self.im_b(x[..., 1]))
           return torch.stack(parts, dim=-1)
   
   
   # + colab={} colab_type="code" id="N7W37XMO5i-B"
   class Encoder(nn.Module):
       """
       Encoder block of the U-Net: convolution -> batch norm -> LeakyReLU.

       Strides greater than 1 shrink the spatial size (the default stride is
       (2, 2)). With ``complex=True`` the complex-valued CConv2d/CBatchNorm2d
       wrappers are used, and tensors carry real/imaginary parts on the last
       axis. Otherwise plain ``nn.Conv2d`` / ``nn.BatchNorm2d`` layers are used.
       """
       def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45, padding=(0,0), complex=True):
           super().__init__()

           self.filter_size = filter_size
           self.stride_size = stride_size
           self.in_channels = in_channels
           self.out_channels = out_channels
           self.padding = padding
           # Recorded, as Decoder does, so the block's mode can be inspected.
           self.complex = complex

           if complex:
               self.cconv = CConv2d(in_channels=self.in_channels, out_channels=self.out_channels,
                                    kernel_size=self.filter_size, stride=self.stride_size, padding=self.padding)
               self.cbn = CBatchNorm2d(num_features=self.out_channels)
           else:
               self.cconv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels,
                                      kernel_size=self.filter_size, stride=self.stride_size, padding=self.padding)
               self.cbn = nn.BatchNorm2d(num_features=self.out_channels)

           self.leaky_relu = nn.LeakyReLU()

       def forward(self, x):
           conved = self.cconv(x)
           normed = self.cbn(conved)
           return self.leaky_relu(normed)
   
   
   # + colab={} colab_type="code" id="fuugYDZs5i-G"
   class Decoder(nn.Module):
       """
       Decoder block of the U-Net: transposed convolution (upsampling when the
       stride is greater than 1) -> batch norm -> LeakyReLU.

       With ``complex=True`` the complex-valued CConvTranspose2d/CBatchNorm2d
       wrappers are used, and tensors carry real/imaginary parts on the last
       axis. In that mode the final block (``last_layer=True``) skips BN and the
       activation and instead emits the bounded mask
           tanh(|z|) * z / (|z| + 1e-8),
       where |z| = sqrt(re**2 + im**2) is taken over the last axis.

       With ``complex=False``, ``last_layer`` has no effect: every block,
       including the last one, applies BN + LeakyReLU.
       """
       def __init__(self, filter_size=(7,5), stride_size=(2,2), in_channels=1, out_channels=45,
                    output_padding=(0,0), padding=(0,0), last_layer=False, complex=True):
           super().__init__()

           self.filter_size = filter_size
           self.stride_size = stride_size
           self.in_channels = in_channels
           self.out_channels = out_channels
           self.output_padding = output_padding
           self.padding = padding
           self.complex = complex

           self.last_layer = last_layer

           if complex:
               self.cconvt = CConvTranspose2d(in_channels=self.in_channels, out_channels=self.out_channels,
                                              kernel_size=self.filter_size, stride=self.stride_size,
                                              output_padding=self.output_padding, padding=self.padding)
               self.cbn = CBatchNorm2d(num_features=self.out_channels)
           else:
               self.cconvt = nn.ConvTranspose2d(in_channels=self.in_channels, out_channels=self.out_channels,
                                                kernel_size=self.filter_size, stride=self.stride_size,
                                                output_padding=self.output_padding, padding=self.padding)
               self.cbn = nn.BatchNorm2d(num_features=self.out_channels)

           self.leaky_relu = nn.LeakyReLU()

       def forward(self, x):
           conved = self.cconvt(x)

           if self.complex and self.last_layer:
               # Magnitude of each complex value over the stacked (real, imag)
               # axis. torch.abs is element-wise, so using it here would
               # reduce the mask to tanh(re), tanh(im) instead of a polar mask.
               magnitude = conved.norm(p=2, dim=-1, keepdim=True)
               m_phase = conved / (magnitude + 1e-8)
               m_mag = torch.tanh(magnitude)
               return m_phase * m_mag

           normed = self.cbn(conved)
           return self.leaky_relu(normed)
   
   class DCUnet20(nn.Module):
       """
       Deep Complex U-Net with 20 layers: 10 Encoder blocks followed by 10
       Decoder blocks, joined by channel-axis (dim=1) skip concatenations.

       Despite the name, the default is the real-valued variant
       (``complex=False``). The final decoder output is used as a mask: it is
       multiplied with the network input, and the product is squeezed on dim 1.
       """
       def __init__(self, complex=False):
           super().__init__()

           self.complex = complex

           # Layer tables; base width is int(45 // 1.414) == 31 channels.
           self.set_size(model_complexity=int(45//1.414), input_channels=1, model_depth=20)
           self.encoders = []
           self.model_length = 20 // 2
           depth = self.model_length

           # Blocks are registered individually as "encoder{i}" / "decoder{i}"
           # (plain lists + add_module), which fixes the state_dict key names.
           for idx in range(depth):
               block = Encoder(in_channels=self.enc_channels[idx],
                               out_channels=self.enc_channels[idx + 1],
                               filter_size=self.enc_kernel_sizes[idx],
                               stride_size=self.enc_strides[idx],
                               padding=self.enc_paddings[idx],
                               complex=complex)
               self.add_module("encoder{}".format(idx), block)
               self.encoders.append(block)

           self.decoders = []

           for idx in range(depth):
               # Input width = previous decoder's channels plus the channels of
               # the skip tensor concatenated onto it in forward().
               block = Decoder(in_channels=self.dec_channels[idx] + self.enc_channels[depth - idx],
                               out_channels=self.dec_channels[idx + 1],
                               filter_size=self.dec_kernel_sizes[idx],
                               stride_size=self.dec_strides[idx],
                               padding=self.dec_paddings[idx],
                               output_padding=self.dec_output_padding[idx],
                               last_layer=(idx == depth - 1),
                               complex=complex)
               self.add_module("decoder{}".format(idx), block)
               self.decoders.append(block)

       def forward(self, x):
           log_('x : ', x.shape)
           orig_x = x
           skips = []
           for encoder in self.encoders:
               # Keep each encoder's *input* for the mirrored decoder stage.
               skips.append(x)
               x = encoder(x)
               log_('Encoder : ', x.shape)

           p = x
           last = self.model_length - 1
           for i, decoder in enumerate(self.decoders):
               p = decoder(p)
               if i == last:
                   break
               log_('Decoder : ', p.shape)
               p = torch.cat([p, skips[last - i]], dim=1)

           mask = p
           log_('mask : ', mask.shape)

           return torch.squeeze(mask * orig_x, 1)

       def set_size(self, model_complexity, model_depth=20, input_channels=1):
           """Fill the per-layer channel, kernel, stride and padding tables.

           Only ``model_depth == 20`` is supported; any other value raises
           ``ValueError`` before any attribute is set.
           """
           if model_depth != 20:
               raise ValueError("Unknown model depth : {}".format(model_depth))

           c = model_complexity
           c2 = model_complexity * 2

           self.enc_channels = [input_channels, c, c] + [c2] * 7 + [128]
           self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5)] + [(5, 3)] * 6
           self.enc_strides = [(1, 1), (1, 1)] + [(2, 2), (2, 1)] * 4
           self.enc_paddings = [(3, 0), (0, 3)] + [(0, 0)] * 8

           self.dec_channels = [0] + [c2] * 7 + [c, c, 1]
           self.dec_kernel_sizes = [(6, 3), (6, 3), (6, 3), (6, 4), (6, 3),
                                    (6, 4), (8, 5), (7, 5), (1, 7), (7, 1)]
           self.dec_strides = [(2, 1), (2, 2)] * 4 + [(1, 1), (1, 1)]
           self.dec_paddings = [(0, 0)] * 8 + [(0, 3), (3, 0)]
           self.dec_output_padding = [(0, 0)] * 10
   
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: issues-unsubscr...@mxnet.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: issues-unsubscr...@mxnet.apache.org
For additional commands, e-mail: issues-h...@mxnet.apache.org

Reply via email to