haojin2 commented on a change in pull request #14445: Speedup SequenceMask on 
GPU
URL: https://github.com/apache/incubator-mxnet/pull/14445#discussion_r268354211
 
 

 ##########
 File path: src/operator/sequence_mask-inl.h
 ##########
 @@ -65,70 +65,24 @@ struct SequenceMaskParam : public 
dmlc::Parameter<SequenceMaskParam> {
   }
 };
 
-// (seqlen, batch, rest) case
-template <int req>
-struct SequenceMask0Kernel {
-  template <typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
-                                  index_t max_s_len, index_t batch_size,
-                                  index_t restsize, DType value) {
-    const index_t seqpos = static_cast<int>(idx[b]);
-#pragma unroll
-    for (index_t s = seqpos; s < max_s_len; ++s) {
-      index_t incr = (s * batch_size * restsize) + (b * restsize);
-#pragma unroll
-      for (index_t r = 0; r < restsize; ++r)
-        KERNEL_ASSIGN(in[incr + r], req, value);
-    }
-  }
-};
-
-// (batch, seqlen, rest) case
-template <int req>
-struct SequenceMask1Kernel {
-  template <typename DType, typename IType>
-  MSHADOW_XINLINE static void Map(int b, DType *in, const IType *idx,
-                                  index_t max_s_len, index_t batch_size,
-                                  index_t restsize, DType value) {
-    const index_t seqpos = static_cast<int>(idx[b]);
-#pragma unroll
-    for (index_t s = seqpos; s < max_s_len; ++s) {
-      index_t incr = (b * max_s_len * restsize) + (s * restsize);
-#pragma unroll
-      for (index_t r = 0; r < restsize; ++r)
-        KERNEL_ASSIGN(in[incr + r], req, value);
-    }
-  }
-};
+template<typename DType, typename IType>
+void SequenceMaskExec(const mshadow::Tensor<cpu, 3, DType> &data,
+                  const mshadow::Tensor<cpu, 1, IType> &indices,
+                  const OpReqType req, mshadow::Stream<cpu> *const s,
+                  int axis, DType val);
+#ifdef __CUDACC__
+template<typename DType, typename IType>
 
 Review comment:
   I tried that and it failed; partial specialization is not possible in this 
case.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to