Wallart opened a new issue #17256: Sparse compression causes errors
URL: https://github.com/apache/incubator-mxnet/issues/17256

Hello everyone,

I am trying to use sparse tensors to save memory in my Transformer architecture, and I'm applying `F.sparse.cast_storage` to an attention-weights tensor:

```
import math

from mxnet import gluon


class ScaledDotProductAttn(gluon.HybridBlock):
    def __init__(self, dim_k, *args, **kwargs):
        super(ScaledDotProductAttn, self).__init__(*args, **kwargs)
        self._dim_k = dim_k

    def hybrid_forward(self, F, *args, **kwargs):
        query, key, value, mask, sparse_pattern = args
        # (..., seq_len_q, seq_len_k)
        matmul_qk = F.linalg.gemm2(query, key, transpose_b=True)
        scaled_attn_logits = matmul_qk / math.sqrt(self._dim_k)
        if mask is not None:
            scaled_attn_logits = F.broadcast_add(scaled_attn_logits, mask * -1e9)
        # (..., seq_len_q, seq_len_k)
        attn_weights = F.softmax(scaled_attn_logits)
        if sparse_pattern is not None:
            attn_weights = F.sparse.cast_storage(attn_weights * sparse_pattern, 'csr')
        # (..., seq_len_q, dim_v)
        output = F.linalg.gemm2(attn_weights, value)
        return output, attn_weights
```

As you can see, the sparse NDArray is densified on the fly to produce `output` (because `value` is not sparse). I then return a dense `output` and a sparse `attn_weights`. `output` is ultimately used to compute the loss, and `attn_weights` is kept for plotting when needed.

The error occurs when I update the loss metric, which calls `asnumpy` internally:
```
Traceback (most recent call last):
  File "/home/wallart/workspaces/Transformer/trainer/transformer_trainer.py", line 77, in train
    self._loss_metric.update(0, [l * self._opts.batch_size for l in losses])
  File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/metric.py", line 1687, in update
    loss = ndarray.sum(pred).asscalar()
  File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/ndarray/ndarray.py", line 2553, in asscalar
    return self.asnumpy()[0]
  File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/ndarray/ndarray.py", line 2535, in asnumpy
    ctypes.c_size_t(data.size)))
  File "/opt/miniconda3/envs/intelpython3/lib/python3.6/site-packages/mxnet-1.6.0-py3.6.egg/mxnet/base.py", line 255, in check_call
    raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [10:01:05] src/operator/tensor/././cast_storage-inl.cuh:470: Check failed: dns.shape_.ndim() == 2 (4 vs. 2)
```

The issue occurs on both MXNet 1.5.1 and 1.6.0.rc0. Everything works if I disable the `F.sparse.cast_storage` call.
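The check `dns.shape_.ndim() == 2 (4 vs. 2)` suggests the CSR cast only accepts 2-D input, while the attention weights here are 4-D (batch, heads, seq_len_q, seq_len_k). A possible workaround, sketched below with NumPy/SciPy since CSR is an inherently 2-D format, is to collapse the leading dimensions into rows before the cast and restore the shape after densifying; the MXNet equivalent would presumably be an `F.reshape` around `F.sparse.cast_storage`, which I have not verified.

```
import numpy as np
from scipy import sparse

# A 4-D stand-in for the attention weights: (batch, heads, seq_len_q, seq_len_k).
attn = np.zeros((2, 4, 5, 5), dtype=np.float32)
attn[0, 0, 0, 0] = 1.0  # a single non-zero entry

# CSR only represents matrices, so fold the leading dims into rows first...
flat = attn.reshape(-1, attn.shape[-1])  # shape (2*4*5, 5)
csr = sparse.csr_matrix(flat)            # the 2-D cast succeeds

# ...then densify and restore the original 4-D shape when needed.
restored = np.asarray(csr.todense()).reshape(attn.shape)
assert (restored == attn).all()
```

This only sidesteps the shape check; whether the reshaped CSR array still composes with `F.linalg.gemm2` as in the block above would need to be tested separately.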