ZheyuYe opened a new issue #18717:
URL: https://github.com/apache/incubator-mxnet/issues/18717


   ## Description
   ```python
   import mxnet as mx
   from mxnet.gluon import HybridBlock, nn
   import tempfile
   import os
   mx.npx.set_np()
   
   
   class Foo(HybridBlock):
       def __init__(self, use_mlm=False):
           super().__init__()
           self.use_mlm = use_mlm
           self.vocab_size = 30522
           self.word_embed = nn.Embedding(input_dim=self.vocab_size,
                                          output_dim=64)
           
           if self.use_mlm:
               self.mlm_decoder = nn.HybridSequential()
               self.mlm_decoder.add(nn.Dense(units=64, flatten=False))
                self.mlm_decoder.add(nn.Dense(units=self.vocab_size, flatten=False))
                self.mlm_decoder[-1].share_parameters(self.word_embed.collect_params())
   
       def hybrid_forward(self, F, x):
           x = self.word_embed(x)
           if self.use_mlm:
               x = self.mlm_decoder(x)
           return x
   
   foo = Foo(use_mlm=True)
   foo.initialize()
   foo(mx.np.ones((8,)))
   foo2 = Foo(use_mlm=False)
   with tempfile.TemporaryDirectory() as dir_path:
        foo.save_parameters(os.path.join(dir_path, 'test.params'), deduplicate=True)
        parameters = mx.npx.load(os.path.join(dir_path, 'test.params'))
        print(parameters.keys())
       foo2.load_parameters(os.path.join(dir_path, 'test.params'))
   ```
   Output:
   ```bash
   >>>dict_keys(['l2.weight', 'l2.bias'])
    >>>AssertionError: Parameter 'l1.weight' is missing in 'file: /tmp/tmp3a6xslz2/test.params', which contains parameters: 'l2.weight', 'l2.bias'. Set allow_missing=True to ignore missing parameters.
   ```
    Here `l1` and `l2` are shared, and thanks to the `deduplicate` flag the shared parameters are saved only once, keyed by the last parameter's name, which is why we get `dict_keys(['l2.weight', 'l2.bias'])`. There is nothing wrong with that unless we only load part of the parameters, as with `foo2 = Foo(use_mlm=False)`.
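
    A quick check (a sketch on my side, assuming `collect_params()` exposes the same structural names that end up in the saved file) shows why only one key survives deduplication in the model above:
    ```python
    foo = Foo(use_mlm=True)
    params = foo.collect_params()
    # after share_parameters, both structural names refer to the same Parameter
    # object, so save_parameters(..., deduplicate=True) stores the array only once
    print(params['word_embed.weight'] is params['mlm_decoder.1.weight'])  # expected: True
    ```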
   
    Of course we could avoid this problem by calling `l1` repeatedly instead of creating a separate layer `l2` that shares weights with `l1`. But the following scenario is fairly common in pretrained models with masked language modelling as the pretraining objective:
   ```python
   import mxnet as mx
   from mxnet.gluon import HybridBlock, nn
   import tempfile
   import os
   mx.npx.set_np()
   
   
   class Foo(HybridBlock):
       def __init__(self, use_mlm=False):
           super().__init__()
           self.use_mlm = use_mlm
           self.vocab_size = 30522
           self.word_embed = nn.Embedding(input_dim=self.vocab_size,
                                          output_dim=64)
           
           if self.use_mlm:
               self.mlm_decoder = nn.HybridSequential()
               self.mlm_decoder.add(nn.Dense(units=64, flatten=False))
                self.mlm_decoder.add(nn.Dense(units=self.vocab_size, flatten=False))
                self.mlm_decoder[-1].share_parameters(self.word_embed.collect_params())
   
       def hybrid_forward(self, F, x):
           x = self.word_embed(x)
           if self.use_mlm:
               x = self.mlm_decoder(x)
           return x
   
   foo = Foo(use_mlm=True)
   foo.initialize()
   foo(mx.np.ones((8,)))
   foo2 = Foo(use_mlm=False)
   with tempfile.TemporaryDirectory() as dir_path:
        foo.save_parameters(os.path.join(dir_path, 'test.params'), deduplicate=True)
        parameters = mx.npx.load(os.path.join(dir_path, 'test.params'))
        print(parameters.keys())
       foo2.load_parameters(os.path.join(dir_path, 'test.params'))
   ```
   
   ```bash
    >>>dict_keys(['mlm_decoder.1.weight', 'mlm_decoder.0.weight', 'mlm_decoder.0.bias', 'mlm_decoder.1.bias'])
   ```
    Here `mlm_decoder` is only used in pretraining and would be discarded when fine-tuning on down-stream tasks. Inside `mlm_decoder`, we usually predict the masked tokens by mapping back to the vocabulary indices through a dense layer whose parameters are shared with `word_embed`. However, saving this way produces a parameter file without `word_embed.weight`, which then cannot be loaded into the fine-tuning model `foo2`.
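
    One workaround I can think of (an untested sketch; it assumes the shared array really is stored under `mlm_decoder.1.weight`, as the printed keys suggest) is to load whatever names match and then copy the deduplicated array back by hand:
    ```python
    foo2 = Foo(use_mlm=False)
    foo2.initialize()
    with tempfile.TemporaryDirectory() as dir_path:
        path = os.path.join(dir_path, 'test.params')
        foo.save_parameters(path, deduplicate=True)
        # ignore_extra skips the mlm_decoder.* names foo2 does not have,
        # allow_missing tolerates the word_embed.weight the file lacks ...
        foo2.load_parameters(path, allow_missing=True, ignore_extra=True)
        # ... then restore the shared weight under its other name manually
        arrays = mx.npx.load(path)
        foo2.word_embed.weight.set_data(arrays['mlm_decoder.1.weight'])
    ```
    This is clearly not something users should have to do for every tied parameter, hence this issue.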

