learning-chip opened a new issue #20605:
URL: https://github.com/apache/incubator-mxnet/issues/20605


   ## Description
   
   The gradient of a `CSRNDArray` is still of `CSRNDArray` type, but every entry is stored as a non-zero, i.e. the gradient has a fully dense sparsity pattern. Although this is mathematically correct (the input is treated as a dense matrix), it is not what a user would expect. Users would intuitively expect zero entries to be ignored, and only the gradient of the non-zero entries to be tracked. As of MXNet 1.8.0, there seems to be no way to take a **sparse gradient** of a `CSRNDArray`, which should preserve the sparsity pattern of the original matrix.
   
   As a comparison, in PyTorch (as of 1.8.1), the gradient of a CSR/COO sparse tensor follows the sparsity pattern of the original matrix (pytorch/pytorch#63744).
   
   ### Steps to reproduce
   
   In MXNet 1.8.0, run the following code:
   
   ```python
   import numpy as np
   import scipy.sparse as sp
   from mxnet import nd, autograd
   
   A = sp.diags([1, 2, 3], dtype='float64', format='csr')
   x = np.ones((3, 1))
   
   A_nd = nd.sparse.array(A, dtype=A.dtype)
   x_nd = nd.array(x, dtype=x.dtype)
   A_nd.attach_grad()
   x_nd.attach_grad()
   
   with autograd.record():
       Ax_nd = nd.sparse.dot(A_nd, x_nd)
       loss = Ax_nd.sum()
       loss.backward()
   
   print(x_nd.grad)  # shows [1, 2, 3], correct
   print(A_nd.grad.asnumpy())  # shows a 3x3 dense matrix filled with 1.0
   print(A_nd.grad.data, A_nd.grad.indices, A_nd.grad.indptr)  # shows 9 non-zero values, instead of 3 as in A
   ```
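   
   To make the expected behaviour concrete: what I am after is equivalent to restricting the dense-pattern gradient above to `A`'s stored positions. A manual post-processing sketch, for illustration only (not an MXNet API), reusing `A` and `A_nd` from the snippet above:
   
   ```python
   # continue on the above code section
   grad_dense = A_nd.grad.asnumpy()  # 3x3 array; every entry is stored as non-zero
   rows, cols = A.nonzero()          # positions of A's 3 stored values
   sparse_grad = sp.csr_matrix((grad_dense[rows, cols], (rows, cols)), shape=A.shape)
   print(sparse_grad.data)  # [1. 1. 1.], i.e. the gradient w.r.t. A's stored values only
   ```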
   
   Then I tried tracking the gradient of the sparse values directly, before constructing the sparse matrix object:
   
   ```python
   # continue on the above code section
   indices = A.indices  # [0, 1, 2]
   indptr = A.indptr  # [0, 1, 2, 3]
   
   data_nd = nd.array([1, 2, 3], dtype='float64')
   data_nd.attach_grad()
   
   with autograd.record():
       B_nd = nd.sparse.csr_matrix((data_nd, indices, indptr))  # warning: gradient tracking is lost here!
       Bx_nd = nd.sparse.dot(B_nd, x_nd)
       loss = Bx_nd.sum()
       loss.backward()
   
   print(data_nd.grad)  # shows [0, 0, 0], incorrect; should have been [1, 1, 1]
   ```
   
   As a comparison, in PyTorch (1.8.1), the sparse gradient (w.r.t. the values) can be computed with similar logic:
   
   ```python
   import torch
   
   row = [0, 1, 2]  # row indices of the non-zero (diagonal) entries
   col = [0, 1, 2]  # column indices of the non-zero (diagonal) entries
   A_value_th = torch.tensor([1, 2, 3], dtype=torch.float64).requires_grad_(True)
   A_th = torch.sparse_coo_tensor((row, col), A_value_th)
   print(A_th.to_dense())  # shows the 3x3 diagonal matrix
   
   x_th = torch.ones((3, 1), dtype=torch.float64)
   Ax_th = torch.sparse.mm(A_th, x_th)
   loss = Ax_th.sum()
   loss.backward()
   print(A_value_th.grad)  # shows [1, 1, 1], correct
   ```
   
   Of course, you can also call `requires_grad_` on the sparse tensor itself (`A_th` here), and then the gradient will also be a sparse tensor that follows the original sparsity pattern (https://github.com/pytorch/pytorch/issues/63744#issuecomment-903858260).
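   
   A minimal sketch of that variant, reusing `row`, `col` and `x_th` from the snippet above (the sparse layout of `.grad` is as described in the linked comment):
   
   ```python
   A_th2 = torch.sparse_coo_tensor((row, col), torch.tensor([1.0, 2.0, 3.0], dtype=torch.float64))
   A_th2.requires_grad_(True)  # track the gradient on the sparse tensor itself
   
   loss = torch.sparse.mm(A_th2, x_th).sum()
   loss.backward()
   
   print(A_th2.grad)             # a sparse tensor with A_th2's sparsity pattern
   print(A_th2.grad.to_dense())  # ones at the diagonal positions, zeros elsewhere
   ```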
   
   ## Environment
   
   ```
   mxnet==1.8.0
   torch==1.8.1
   ```
   

