LeiWang1999 commented on code in PR #15462:
URL: https://github.com/apache/tvm/pull/15462#discussion_r1287053938
##########
src/runtime/contrib/cudnn/conv_forward.cc:
##########
@@ -132,17 +133,61 @@ void FindAlgo(int format, int dims, int groups, const int pad[], const int strid
"CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED"};
auto best_algo = perf_results[0].algo;
- LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing "
- << fwd_algo_names[best_algo];
- for (int i = 0; i < returned_algo_count; ++i) {
- LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
- << " - time: " << perf_results[i].time << " ms"
- << ", Memory: " << perf_results[i].memory;
+ if (verbose) {
+ LOG(INFO) << "\tCUDNN Found " << returned_algo_count << " fwd algorithms, choosing "
+ << fwd_algo_names[best_algo];
+ for (int i = 0; i < returned_algo_count; ++i) {
+ LOG(INFO) << "\t\t" << i << ") " << fwd_algo_names[perf_results[i].algo]
+ << " - time: " << perf_results[i].time << " ms"
+ << ", Memory: " << perf_results[i].memory;
+ }
}
ret[0] = best_algo;
}
+void CallCudnnConvolutionForward(cudnnHandle_t handle, cudaStream_t stream,
int mode, int format,
+ int algo, int dims, int groups, const int
pad[],
+ const int stride[], const int dilation[],
const DLTensor* x,
+ const DLTensor* w, const DLTensor* y,
+ const std::string& conv_dtype) {
+ CuDNNThreadEntry* entry_ptr = CuDNNThreadEntry::ThreadLocal();
+
+ // Set Mode
+ entry_ptr->conv_entry.mode = static_cast<cudnnConvolutionMode_t>(mode);
+
+ // Set Descriptors
+ SetConvDescriptors(entry_ptr, format, dims, groups, pad, stride, dilation,
x->shape, w->shape,
+ y->shape, x->dtype, conv_dtype);
+
+ // Set Device
+ entry_ptr->conv_entry.device = x->device;
+
+ // Set Algo
+ entry_ptr->conv_entry.fwd_algo =
static_cast<cudnnConvolutionFwdAlgo_t>(algo);
+
+ // Set workspace
+ size_t workspace_size = 0;
+ CUDNN_CALL(cudnnGetConvolutionForwardWorkspaceSize(
+ handle, entry_ptr->conv_entry.input_desc,
entry_ptr->conv_entry.filter_desc,
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.output_desc,
+ entry_ptr->conv_entry.fwd_algo, &workspace_size));
+
+ entry_ptr->conv_entry.UpdateWorkspace(workspace_size);
+
+ // Compute convolution
+ CUDNN_CALL(cudnnConvolutionForward(
+ handle, CuDNNDataType::GetConst<1>(entry_ptr->conv_entry.data_type),
+ entry_ptr->conv_entry.input_desc, x->data,
entry_ptr->conv_entry.filter_desc, w->data,
+ entry_ptr->conv_entry.conv_desc, entry_ptr->conv_entry.fwd_algo,
+ entry_ptr->conv_entry.workspace, workspace_size,
+ CuDNNDataType::GetConst<0>(entry_ptr->conv_entry.data_type),
+ entry_ptr->conv_entry.output_desc, y->data));
+
+ // Set the stream to be used by cuDNN
+ cudnnSetStream(handle, stream);
Review Comment:
typo....
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
users@infra.apache.org