Xingjian Shi created MXNET-507:
----------------------------------

             Summary: Two problems in our ordering operators (topk, sort, argsort)
                 Key: MXNET-507
                 URL: https://issues.apache.org/jira/browse/MXNET-507
             Project: Apache MXNet
          Issue Type: Bug
            Reporter: Xingjian Shi

There are two problems in the ordering operators, i.e., topk, sort, and argsort:

1) Only real_t is supported as the input dtype.
2) The indices are stored as real_t. This causes errors in the backward pass, where gradients are passed to the wrong locations. Because real_t is a 32-bit float by default, it can represent integers exactly only up to 2^24 (16,777,216), so larger indices are rounded.

For example, the following code fails in the previous version:

```python
import mxnet as mx
import numpy as np

ctx = mx.cpu()
# Large enough that the top-k indices exceed 2**24.
a = mx.nd.arange(54686454, ctx=ctx, dtype=np.int32)
a.attach_grad()
k = 10
with mx.autograd.record():
    b = mx.nd.topk(a, k=k, ret_typ='value')
b.backward(mx.nd.ones((k,), ctx=ctx, dtype=np.int32))
a_grad = a.grad.asnumpy()
# The k largest values are the last k elements, so each should get gradient 1.
for i in range(-1, -k - 1, -1):
    assert a_grad[i] == 1
```

I propose to fix this bug by changing the dtype of the indices to int32. However, this change is backward incompatible.
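
The index problem can be reproduced without MXNet. The following is a minimal NumPy sketch, assuming real_t is a 32-bit float (the default build setting); the variable names are illustrative and do not come from the MXNet source. It casts the expected top-k indices from the example above to float32 and back, the way an index stored as real_t would be read back in the backward pass.

```python
import numpy as np

n = 54686454   # array length from the example above
k = 10

# Indices that topk should select: the last k positions of the array.
true_idx = np.arange(n - k, n, dtype=np.int64)

# Assumption: real_t == float32. Store the indices as floats, then read them back.
stored = true_idx.astype(np.float32)
recovered = stored.astype(np.int64)

# float32 has a 24-bit significand, so integers above 2**24 are rounded.
mismatches = int((recovered != true_idx).sum())
distinct = len(set(recovered.tolist()))

print("indices changed by the round trip:", mismatches, "of", k)
print("distinct indices after the round trip:", distinct)
```

Several distinct indices collapse onto the same float32 value, so the gradient for those elements accumulates at one location and the others receive none. Storing the indices in an integer dtype avoids the rounding.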