sxjscience opened a new issue #19472: URL: https://github.com/apache/incubator-mxnet/issues/19472
## Description

I tried to use the GluonNLP BART model with AMP but hit a strange error. What I am doing is the following:

- Create the FP32 model
- Run an FP32 forward pass
- Create an FP16 model
- Share the weights of the FP32 model with the FP16 model
- Run an FP16 forward pass
- Turn on AMP
- Run an FP32 forward pass (triggers the error)

To reproduce, you will first need to install the master version of GluonNLP:

```
git clone https://github.com/dmlc/gluon-nlp; cd gluon-nlp
git checkout master
python3 -m pip install --quiet -e .[extras]
```

```python
import mxnet as mx
import numpy as np
mx.npx.set_np()

from gluonnlp.models.bart import BartModel

cfg = BartModel.get_cfg()
cfg.defrost()
cfg.MODEL.vocab_size = 32
cfg.freeze()

ctx = mx.gpu()
batch_size = 4
src_length = 32
tgt_length = 16

with ctx:
    src_data = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, src_length), dtype=np.int32)
    src_valid_length = mx.np.random.randint(src_length // 2, src_length, (batch_size,), dtype=np.int32)
    tgt_data = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, tgt_length), dtype=np.int32)
    tgt_valid_length = mx.np.random.randint(tgt_length // 2, tgt_length, (batch_size,), dtype=np.int32)

    model = BartModel.from_cfg(cfg)
    model.initialize(ctx=ctx)
    model.hybridize()
    model(src_data, src_valid_length, tgt_data, tgt_valid_length)
    mx.npx.waitall()

    model_fp16 = BartModel.from_cfg(cfg)
    model_fp16.share_parameters(model.collect_params())
    model_fp16.cast('float16')
    model_fp16.hybridize()
    model_fp16(src_data, src_valid_length, tgt_data, tgt_valid_length)
    mx.npx.waitall()

    from mxnet import amp
    amp.init()
    model(src_data, src_valid_length, tgt_data, tgt_valid_length)
    mx.npx.waitall()

    trainer = mx.gluon.Trainer(model.collect_params(), 'adam',
                               {'learning_rate': 1E-3, 'wd': 1E-4, 'multi_precision': True},
                               update_on_kvstore=False)
    amp.init_trainer(trainer)
    with mx.autograd.record():
        outputs_amp = model(src_data, src_valid_length, tgt_data, tgt_valid_length)
        if not isinstance(outputs_amp, (tuple, list)):
            loss = outputs_amp.mean()
        else:
            loss = sum([ele.mean() for ele in outputs_amp])
        with amp.scale_loss(loss, trainer) as scaled_loss:
            mx.autograd.backward(scaled_loss)
    trainer.step(1)
    mx.npx.waitall()
```

### Error Message

```
---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
<ipython-input-1-73c1077ec25a> in <module>
     33 model_fp16.cast('float16')
     34 model_fp16.hybridize()
---> 35 model_fp16(src_data, src_valid_length, tgt_data, tgt_valid_length)
     36 mx.npx.waitall()
     37 from mxnet import amp

~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_shape(*args, **kwargs)
    297         def _with_np_shape(*args, **kwargs):
    298             with np_shape(active=True):
--> 299                 return func(*args, **kwargs)
    300         return _with_np_shape
    301     else:

~/.local/lib/python3.6/site-packages/mxnet/util.py in _with_np_array(*args, **kwargs)
    478         def _with_np_array(*args, **kwargs):
    479             with np_array(active=True):
--> 480                 return func(*args, **kwargs)
    481         return _with_np_array
    482     else:

[... the two mxnet/util.py frames above repeat between most of the frames below; omitted for brevity ...]

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
   1431
   1432         with x.ctx:
-> 1433             return self._call_cached_op(x, *args)
   1434
   1435     def forward(self, x, *args):
~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _call_cached_op(self, *args)
   1100     def _call_cached_op(self, *args):
   1101         if self._cached_op is None:
-> 1102             self._build_cache(*args)
   1103         assert self._cached_op, "Gluon failed to build the cache. " \
   1104             "This should never happen. " \

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _build_cache(self, *args)
    993
    994     def _build_cache(self, *args):
--> 995         data, out = self._get_graph(*args)
    996         data_names = {data.name: i for i, data in enumerate(data)}
    997         params = {p.var().name: p for p in self.collect_params().values()}

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _get_graph(self, *args)
    989             return self._get_graph_v1(*args)
    990         else:  # Gluon 2 based on deferred compute mode
--> 991             return self._get_graph_v2(*args)
    992         return self._cached_graph
    993

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in _get_graph_v2(self, *args)
    978         args = _regroup(flatten_args, self._in_format)
    979         with autograd.pause(), dc.context():
--> 980             out = super().__call__(*args)
    981         flatten_out, self._out_format = _flatten(out, "output")
    982         symbol_outputs = dc.get_symbol(flatten_out, sym_cls=type(symbol_inputs[0]))

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
    709             hook(self, args)
    710
--> 711         out = self.forward(*args)
    712
    713         for hook in self._forward_hooks.values():

~/gluon-nlp/src/gluonnlp/models/bart.py in forward(self, src_data, src_valid_length, tgt_data, tgt_valid_length)
    235             Shape (tgt_length, batch_size, tgt_vocab_size)
    236         """
--> 237         enc_out = self.encode(src_data, src_valid_length)
    238         contextual_embedding = self.decode_seq(tgt_data, tgt_valid_length, enc_out,
    239                                                src_valid_length)

~/gluon-nlp/src/gluonnlp/models/transformer.py in encode(self, src_data, src_valid_length)
   1177                                              npx.arange_like(src_data, axis=0)), axis=1)
   1178
-> 1179         enc_out = self.encoder(src_data, src_valid_length)
   1180         return enc_out
   1181

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
   1428             # Deferred compute is already enabled. This typically means that the current
   1429             # HybridBlock is a child block of a HybridBlock that has been hybridized.
-> 1430             return super().__call__(x, *args)
   1431
   1432         with x.ctx:

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
    709             hook(self, args)
    710
--> 711         out = self.forward(*args)
    712
    713         for hook in self._forward_hooks.values():

~/gluon-nlp/src/gluonnlp/models/transformer.py in forward(self, data, valid_length)
    372             else:
    373                 layer = self.layers[i]
--> 374             out, _ = layer(out, attn_mask)
    375         if self._pre_norm:
    376             out = self.ln_final(out)

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
   1428             # Deferred compute is already enabled. This typically means that the current
   1429             # HybridBlock is a child block of a HybridBlock that has been hybridized.
-> 1430             return super().__call__(x, *args)
   1431
   1432         with x.ctx:

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
    709             hook(self, args)
    710
--> 711         out = self.forward(*args)
    712
    713         for hook in self._forward_hooks.values():

~/gluon-nlp/src/gluonnlp/models/transformer.py in forward(self, data, attn_mask)
    257         key = npx.reshape(key, (-2, -2, self._num_heads, -1))
    258         value = npx.reshape(value, (-2, -2, self._num_heads, -1))
--> 259         out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
    260         out = self.attention_proj(out)
    261         out = self.dropout_layer(out)

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, x, *args)
   1428             # Deferred compute is already enabled. This typically means that the current
   1429             # HybridBlock is a child block of a HybridBlock that has been hybridized.
-> 1430             return super().__call__(x, *args)
   1431
   1432         with x.ctx:

~/.local/lib/python3.6/site-packages/mxnet/gluon/block.py in __call__(self, *args)
    709             hook(self, args)
    710
--> 711         out = self.forward(*args)
    712
    713         for hook in self._forward_hooks.values():

~/gluon-nlp/src/gluonnlp/attention_cell.py in forward(self, query, key, value, mask, edge_scores)
    643                                        query_head_units=self._query_head_units,
    644                                        layout=self._layout, use_einsum=self._use_einsum,
--> 645                                        dtype=self._dtype)
    646
    647     def __repr__(self):

~/gluon-nlp/src/gluonnlp/attention_cell.py in multi_head_dot_attn(query, key, value, mask, edge_scores, dropout, scaled, normalized, eps, query_head_units, layout, use_einsum, dtype)
    540         else:
    541             context_vec = npx.batch_dot(attn_weights,
--> 542                                         np.swapaxes(value, 1, 2)).transpose((0, 2, 1, 3))
    543             context_vec = npx.reshape(context_vec, (-2, -2, -1))
    544     elif layout == 'TNK':

~/.local/lib/python3.6/site-packages/mxnet/ndarray/register.py in batch_dot(lhs, rhs, transpose_a, transpose_b, forward_stype, out, name, **kwargs)

~/.local/lib/python3.6/site-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out, is_np_op, output_is_list)
     89             c_str_array(keys),
     90             c_str_array([str(s) for s in vals]),
---> 91             ctypes.byref(out_stypes)))
     92
     93         create_ndarray_fn = _global_var._np_ndarray_cls if is_np_op else _global_var._ndarray_cls

~/.local/lib/python3.6/site-packages/mxnet/base.py in check_call(ret)
    244     """
    245     if ret != 0:
--> 246         raise get_last_ffi_error()
    247
    248

MXNetError: Traceback (most recent call last):
  File "../src/io/../operator/elemwise_op_common.h", line 135
MXNetError: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node batch_dot at 1-th input: expected float32, got float16
```
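For reference, the check that fails is MXNet's elementwise type inference, which requires both operands of `batch_dot` to have the same dtype; the message above indicates the first operand (`attn_weights`) is float32 while the second (`np.swapaxes(value, 1, 2)`) is float16. The snippet below is a minimal, untested sketch of that check in isolation (my own illustration, not part of the repro above); it should raise the same kind of `MXNetError`:

```python
import mxnet as mx
import numpy as np

mx.npx.set_np()

# Hypothetical minimal illustration (not from the original report): batch_dot's
# type inference expects both operands to share one dtype, so mixing a float32
# "attn_weights" with a float16 "value" should trigger the same
# "Incompatible attr ... expected float32, got float16" failure as above.
attn_weights = mx.np.ones((2, 4, 8), dtype=np.float32)   # e.g. attention weights left in fp32
value = mx.np.ones((2, 8, 16), dtype=np.float16)         # e.g. a value tensor from the cast model

try:
    mx.npx.batch_dot(attn_weights, value)
    mx.npx.waitall()
except mx.base.MXNetError as err:
    print(err)
```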
