eric-haibin-lin commented on a change in pull request #10939: [MXNET-420] 
broadcast_mul/div between csr and 1D dense on GPU
URL: https://github.com/apache/incubator-mxnet/pull/10939#discussion_r194192949
 
 

 ##########
 File path: src/operator/tensor/elemwise_binary_broadcast_op.h
 ##########
 @@ -229,19 +226,28 @@ struct binary_broadcast_kernel {
   }
 };
 
-template<int req, typename OP>
+template<int req, typename OP, bool col_vec>
 struct csr_dns_csr_broadcast_kernel {
-  template <typename DType, typename CType, typename RType>
+  template<typename DType, typename CType, typename RType>
   MSHADOW_XINLINE static void Map(int row, const DType *csr_data, const CType 
*csr_indices,
                                   const RType *csr_indptr, const DType *dns,
-                                  DType *out, const nnvm::dim_t row_length, 
bool col_vec) {
+                                  DType *out, const nnvm::dim_t row_length) {
     const nnvm::dim_t curr_row_i = csr_indptr[row];
     const nnvm::dim_t next_row_i = csr_indptr[row + 1];
     for (nnvm::dim_t iter = curr_row_i; iter < next_row_i; iter++) {
       KERNEL_ASSIGN(out[iter], req, OP::Map(csr_data[iter],
                     (col_vec)? dns[row] : dns[csr_indices[iter]]));
     }
   }
+
+  template<typename DType>
+  MSHADOW_XINLINE static void Map(int i, const DType *csr_data, const DType* 
scalar_ptr,
 
 Review comment:
   This needs documentation explaining that the kernel always reads 
scalar_ptr[0].

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to