eric-haibin-lin commented on a change in pull request #10371: [MXNET-263] Support for dot(dns, csr) = dns and dot(dns, csr.T) = dns on GPU URL: https://github.com/apache/incubator-mxnet/pull/10371#discussion_r181282253
########## File path: src/operator/tensor/dot-inl.cuh ########## @@ -442,6 +444,99 @@ struct DotCsrRspDnsScalarKernel { } }; +/*! + * \brief GPU Kernel to re-arrange nnz elements to csc order + * Parallelization by output elements: 1 thread/row of csr + */ +struct CscDataIndicesKernel { + template<typename DType, typename IType, typename CType> + __device__ __forceinline__ static void Map(int tid, + const DType* csr_data, + const IType* csr_indices, + const CType* csr_indptr, + DType* csc_data, + AtomicIType* csc_indices, + AtomicIType* csc_indptr, + AtomicIType* col_counters, + const nnvm::dim_t num_rows, + const nnvm::dim_t num_cols) { + if (tid < num_rows) { + for (CType i = csr_indptr[tid]; i < csr_indptr[tid + 1]; ++i) { + // target column + const IType target_col = csr_indices[i]; + const int target_offset = atomicAdd(&col_counters[target_col], 1); Review comment: I don't think this provides a deterministic result. The order of accumulation could be different across runs ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services