wms2537 opened a new issue #20123: URL: https://github.com/apache/incubator-mxnet/issues/20123
## Description

The `argpartition` operator in `mxnet.numpy` is a fallback operator, so it returns its result on the CPU even when the array passed to it lives on a GPU. This causes a context mismatch between the data and the label when using `TopKAccuracy` in `gluon.metric`:

https://github.com/apache/incubator-mxnet/blob/5bc67e24b52ad4bbad0dda90e8ecc4a5a9544e2f/python/mxnet/gluon/metric.py#L506-L508

## To Reproduce

```
>>> import mxnet as mx
>>> a = mx.np.array([3, 4, 2, 1]).as_in_ctx(mx.gpu(0))
>>> a
array([3., 4., 2., 1.], ctx=gpu(0))
>>> b = mx.np.argpartition(a, 3)
>>> b.ctx
cpu(0)
```

### Fix Suggestion

Move `pred_label` back to the label's context after the `argpartition` call:

```
pred_label = numpy.argpartition(pred_label, -self.top_k).as_in_ctx(label.ctx)
```
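The same idea could be generalized to any fallback operator. The following is only a minimal sketch, not MXNet code: `keep_input_ctx`, `StubArray`, and `cpu_fallback_argsort` are hypothetical names introduced for illustration. The stub merely imitates the `.ctx` attribute and `.as_in_ctx()` method seen in the reproduction above, so the snippet runs without MXNet or a GPU.

```python
def keep_input_ctx(op):
    """Wrap `op` so its result is moved back to the first argument's context."""
    def wrapped(arr, *args, **kwargs):
        out = op(arr, *args, **kwargs)
        # A fallback operator may hand back a CPU array even for GPU input.
        if out.ctx != arr.ctx:
            out = out.as_in_ctx(arr.ctx)
        return out
    return wrapped


class StubArray:
    """Hypothetical stand-in for an mxnet.numpy array (demo only)."""
    def __init__(self, data, ctx):
        self.data, self.ctx = list(data), ctx

    def as_in_ctx(self, ctx):
        # Return a copy that reports the requested context.
        return StubArray(self.data, ctx)


def cpu_fallback_argsort(arr):
    """Imitates a fallback operator: the result always lands on 'cpu(0)'."""
    order = sorted(range(len(arr.data)), key=arr.data.__getitem__)
    return StubArray(order, "cpu(0)")


a = StubArray([3, 4, 2, 1], "gpu(0)")
b = keep_input_ctx(cpu_fallback_argsort)(a)
print(b.ctx)
```

With real MXNet arrays, the wrapper would presumably be applied as `keep_input_ctx(mx.np.argpartition)(a, 3)`. Note that it restores the input's context, whereas the fix suggested above uses `label.ctx`; the two only agree when the predictions and labels already share a context.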
