nickguletskii commented on a change in pull request #16790: migrate cudaMemcpy to cudaMemcpyAsync+cudaStreamSynchronize
URL: https://github.com/apache/incubator-mxnet/pull/16790#discussion_r345849814
##########
File path: src/operator/contrib/proposal.cu
##########
@@ -552,8 +557,8 @@ class ProposalGPUOp : public Operator{
   // copy nms result to gpu
   int* keep;
   FRCNN_CUDA_CHECK(cudaMalloc(&keep, sizeof(int) * _keep.size()));
-  FRCNN_CUDA_CHECK(cudaMemcpy(keep, &_keep[0], sizeof(int) * _keep.size(),
-                              cudaMemcpyHostToDevice));
+  FRCNN_CUDA_CHECK(cudaMemcpyAsync(keep, &_keep[0], sizeof(int) * _keep.size(),
+                                   cudaMemcpyHostToDevice, stream));

Review comment:
   Here, the memcpy is scheduled on `stream`, but [line 566](https://github.com/apache/incubator-mxnet/blob/743c10c3d931b84de71a3731a69c6494921dd43d/src/operator/contrib/proposal.cu#L566) launches `PrepareOutput` on the default stream. Unless I am missing something, this introduces a potential race condition: when `stream` is not the default stream, `PrepareOutput` may start before the `cudaMemcpyAsync` has completed.
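   To illustrate the ordering concern, here is a minimal, self-contained sketch (not the MXNet code): `PrepareOutputSketch` is a hypothetical stand-in for the real `PrepareOutput` kernel, and the explicit `cudaStreamSynchronize` is just one way to establish the required ordering.

   ```cpp
   // Sketch only: shows why an async H2D copy on a non-default stream must be
   // ordered before a kernel launched on the default stream.
   #include <cuda_runtime.h>
   #include <cstdio>
   #include <vector>

   // Hypothetical stand-in for PrepareOutput; it merely reads `keep`.
   __global__ void PrepareOutputSketch(const int* keep, int* out, int n) {
     int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i < n) {
       out[i] = keep[i];  // reads `keep`: racy if the H2D copy has not finished
     }
   }

   int main() {
     std::vector<int> _keep = {0, 2, 5, 7};
     const int n = static_cast<int>(_keep.size());
     const size_t bytes = sizeof(int) * n;

     cudaStream_t stream;
     cudaStreamCreate(&stream);

     int *keep = nullptr, *out = nullptr;
     cudaMalloc(&keep, bytes);
     cudaMalloc(&out, bytes);

     // The copy is queued on `stream`, so it only orders against other work
     // submitted to that same (non-default) stream.
     cudaMemcpyAsync(keep, _keep.data(), bytes, cudaMemcpyHostToDevice, stream);

     // The kernel below goes to the default stream, so the copy must be made
     // visible first: either synchronize `stream` (as here) or launch the
     // kernel on `stream` itself.
     cudaStreamSynchronize(stream);

     PrepareOutputSketch<<<1, n>>>(keep, out, n);

     std::vector<int> result(n);
     cudaMemcpy(result.data(), out, bytes, cudaMemcpyDeviceToHost);
     for (int v : result) printf("%d ", v);
     printf("\n");

     cudaFree(keep);
     cudaFree(out);
     cudaStreamDestroy(stream);
     return 0;
   }
   ```

   The alternative, and usually cheaper, fix is to launch the dependent kernel on the same `stream` instead of synchronizing the host.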