This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new b8e8d73  Fix and optimize handling of vectorized memory accesses 
(#17767) (#18095)
b8e8d73 is described below

commit b8e8d7377a7986cfcadf8dd992391a0e615ff7f0
Author: Przemyslaw Tredak <ptre...@nvidia.com>
AuthorDate: Sat Apr 18 16:24:39 2020 -0700

    Fix and optimize handling of vectorized memory accesses (#17767) (#18095)
    
    * Vectorized loads for binary elemwise kernel
    
    * More generalization
    
    * Add backwardusenone
    
    * Remove the unused _backward_add op
    
    * Add vectorized backwardusein
    
    * Extending vectorization to more binary ops, binary ops with scalar and
    unary ops
    
    * Handling ElementwiseSum
    
    * Get rid of half2 in mshadow
    
    * Remove backward_elemwiseaddex
    
    * Revert "Remove the unused _backward_add op"
    
    This reverts commit f86da86f809c8cbad07db76a3554f23890fe05a3.
    
    * Revert "Remove backward_elemwiseaddex"
    
    This reverts commit 7729114caf6a1718c08ce1f35529d2267057d515.
    
    * Add back the backward_add since C++ test relies on it
    
    * Test bcast implementations
    
    * First version of vectorized bcast
    
    * Adding single side vectorized bcast kernel
    
    * Removing debug prints
    
    * Actually run the single side kernel
    
    * Move the default implementation of bcast to the vectorized one
    
    * Limit the new implementation to GPU only
    
    * Enabling vectorization when broadcast does not actually do broadcast
    
    * Cleaning
    
    * Cleaning part 2
    
    * Fix for numpy ops using stuff from broadcast
    
    * Fix
    
    * Fix lint
    
    * Try to debug pinv numpy test
    
    * Fix
    
    * Fix the vectorized broadcast implementation for misaligned input
    pointers
    
    * Added tests
    
    * Added docs to cuda_vectorization.cuh
    
    * Another fix for broadcast and fix INT64 compilation
    
    * Optimize for aligned=true
    
    * 1 more addition to test
    
    * Reverting the change to Numpy op test
    
    * Trying mcmodel=medium to fix the failure in CMake static build
    
    * Revert "Trying mcmodel=medium to fix the failure in CMake static build"
    
    This reverts commit 1af684c507dd5b2c7ab7ffe89d21799320e3d9c6.
    
    * Limiting the PR to just elementwise ops
---
 3rdparty/mshadow/mshadow/base.h                    |  48 ---
 3rdparty/mshadow/mshadow/half2.h                   | 143 ---------
 src/common/cuda_vectorization.cuh                  | 283 ++++++++++++++++++
 src/operator/mshadow_op.h                          |  66 -----
 src/operator/tensor/elemwise_binary_op.cuh         | 322 +++++++++++++++++++++
 src/operator/tensor/elemwise_binary_op.h           | 206 +++++++------
 src/operator/tensor/elemwise_binary_op_basic.cu    |  23 +-
 src/operator/tensor/elemwise_binary_scalar_op.cuh  | 207 +++++++++++++
 src/operator/tensor/elemwise_binary_scalar_op.h    |  75 ++++-
 .../tensor/elemwise_binary_scalar_op_basic.cu      |   9 +-
 .../tensor/elemwise_binary_scalar_op_extended.cu   |  15 +-
 src/operator/tensor/elemwise_sum.cu                | 112 ++++++-
 src/operator/tensor/elemwise_sum.h                 |  12 -
 src/operator/tensor/elemwise_unary_op.cuh          | 127 ++++++++
 src/operator/tensor/elemwise_unary_op.h            |  56 ++--
 src/operator/tensor/elemwise_unary_op_basic.cu     |   1 +
 src/operator/tensor/elemwise_unary_op_pow.cu       |   1 +
 src/operator/tensor/elemwise_unary_op_trig.cu      |   1 +
 tests/python/unittest/test_operator.py             |  78 +++++
 19 files changed, 1342 insertions(+), 443 deletions(-)

diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h
index 28fbd86..2e658cf 100755
--- a/3rdparty/mshadow/mshadow/base.h
+++ b/3rdparty/mshadow/mshadow/base.h
@@ -276,7 +276,6 @@ extern "C" {
   }
 
 #include "./half.h"
-#include "./half2.h"
 #include "./bfloat.h"
 #define MSHADOW_HALF_BF_OPERATOR(RTYPE, OP)                                    
           \
   MSHADOW_XINLINE RTYPE operator OP(mshadow::half::half_t a, 
mshadow::bfloat::bf16_t b) { \
@@ -391,11 +390,6 @@ struct DataType<half::half_t> {
 #endif
 };
 template<>
-struct DataType<half::half2_t> {
-  static const int kFlag = kFloat16;
-  static const int kLanes = 2;
-};
-template<>
 struct DataType<bfloat::bf16_t> {
   static const int kFlag = kBfloat16;
   static const int kLanes = 1;
@@ -1148,48 +1142,6 @@ struct minimum {
   }
 #endif
 
-#define MSHADOW_TYPE_SWITCH_WITH_HALF2(type, DType, ...)  \
-  switch (type) {                                         \
-  case mshadow::kFloat32:                                 \
-    {                                                     \
-      typedef float DType;                                \
-      {__VA_ARGS__}                                       \
-    }                                                     \
-    break;                                                \
-  case mshadow::kFloat64:                                 \
-    {                                                     \
-      typedef double DType;                               \
-      {__VA_ARGS__}                                       \
-    }                                                     \
-    break;                                                \
-  case mshadow::kFloat16:                                 \
-    {                                                     \
-      typedef mshadow::half::half2_t DType;               \
-      {__VA_ARGS__}                                       \
-    }                                                     \
-    break;                                                \
-  case mshadow::kUint8:                                   \
-    {                                                     \
-      typedef uint8_t DType;                              \
-      {__VA_ARGS__}                                       \
-    }                                                     \
-    break;                                                \
-  case mshadow::kInt32:                                   \
-    {                                                     \
-      typedef int32_t DType;                              \
-      {__VA_ARGS__}                                       \
-    }                                                     \
-    break;                                                \
-  case mshadow::kInt64:                                   \
-    {                                                     \
-      typedef int64_t DType;                              \
-      {__VA_ARGS__}                                       \
-    }                                                     \
-    break;                                                \
-  default:                                                \
-    LOG(FATAL) << "Unknown type enum " << type;           \
-  }
-
 #define MSHADOW_SGL_DBL_TYPE_SWITCH(type, DType, ...)  \
   switch (type) {                                      \
   case mshadow::kFloat32:                              \
diff --git a/3rdparty/mshadow/mshadow/half2.h b/3rdparty/mshadow/mshadow/half2.h
deleted file mode 100755
index 3e130c8..0000000
--- a/3rdparty/mshadow/mshadow/half2.h
+++ /dev/null
@@ -1,143 +0,0 @@
-/*!
- *  Copyright (c) 2017 by Contributors
- * \file half2.h
- * \brief definition of vector float16, half2 type.
- *
- * \author Antti-Pekka Hynninen
- */
-#ifndef MSHADOW_HALF2_H_
-#define MSHADOW_HALF2_H_
-
-#if (defined(__CUDACC__) && __CUDA_ARCH__ >= 530 && MSHADOW_USE_CUDA && 
CUDA_VERSION >= 7050)
-  #define MSHADOW_CUDA_HALF2 1
-  #include <cuda_fp16.h>
-#else
-  #define MSHADOW_CUDA_HALF2 0
-#endif
-
-#include<math.h>
-
-/*! \brief namespace for mshadow */
-namespace mshadow {
-/* \brief name space for host/device portable half-precision floats */
-namespace half {
-
-#define MSHADOW_HALF2_ASSIGNOP(AOP, OP)                                   \
-  template<typename T>                                                    \
-  MSHADOW_XINLINE half2_t operator AOP (const T& a) {                     \
-    return *this = half2_t(*this OP a);  /* NOLINT(*)*/                   \
-  }                                                                       \
-
-class MSHADOW_ALIGNED(4) half2_t {
- public:
-#if MSHADOW_CUDA_HALF2
-  half2 half2_;
-#else
-  half_t half_t2[2];
-#endif
-
-  MSHADOW_XINLINE half2_t() {}
-
-#if MSHADOW_CUDA_HALF2
-  MSHADOW_XINLINE explicit half2_t(half2 a) : half2_(a) {}
-#else
-  MSHADOW_XINLINE explicit half2_t(half_t a, half_t b) {
-    half_t2[0] = a;
-    half_t2[1] = b;
-  }
-#endif
-
-  MSHADOW_XINLINE explicit half2_t(int a) {
-#if MSHADOW_CUDA_HALF2
-    half2_ = __half2half2(__int2half_rz(a));
-#else
-    half_t2[0] = (half_t)a;
-    half_t2[1] = (half_t)a;
-#endif
-  }
-
-  MSHADOW_XINLINE half2_t operator+() {
-    return *this;
-  }
-
-  MSHADOW_XINLINE half2_t operator-() {
-#if MSHADOW_CUDA_HALF2
-    return half2_t(__hneg2(half2_));
-#else
-    return half2_t(-half_t2[0], -half_t2[1]);
-#endif
-  }
-
-  MSHADOW_XINLINE half2_t operator=(const half2_t& a) {
-#if MSHADOW_CUDA_HALF2
-    half2_ = a.half2_;
-#else
-    half_t2[0] = a.half_t2[0];
-    half_t2[1] = a.half_t2[1];
-#endif
-    return a;
-  }
-
-  MSHADOW_HALF2_ASSIGNOP(+=, +)
-  MSHADOW_HALF2_ASSIGNOP(-=, -)
-  MSHADOW_HALF2_ASSIGNOP(*=, *)
-  MSHADOW_HALF2_ASSIGNOP(/=, /)
-};
-
-/*! \brief overloaded + operator for half2_t */
-MSHADOW_XINLINE half2_t operator+(half2_t a, half2_t b) {
-#if MSHADOW_CUDA_HALF2
-  return half2_t(__floats2half2_rn(__low2float(a.half2_) + 
__low2float(b.half2_),
-                                   __high2float(a.half2_) + 
__high2float(b.half2_)));
-#else
-  return half2_t(a.half_t2[0] + b.half_t2[0], a.half_t2[1] + b.half_t2[1]);
-#endif
-}
-/*! \brief overloaded - operator for half2_t */
-MSHADOW_XINLINE half2_t operator-(half2_t a, half2_t b) {
-#if MSHADOW_CUDA_HALF2
-  return half2_t(__floats2half2_rn(__low2float(a.half2_) - 
__low2float(b.half2_),
-                                   __high2float(a.half2_) - 
__high2float(b.half2_)));
-#else
-  return half2_t(a.half_t2[0] - b.half_t2[0], a.half_t2[1] - b.half_t2[1]);
-#endif
-}
-/*! \brief overloaded * operator for half2_t */
-MSHADOW_XINLINE half2_t operator*(half2_t a, half2_t b) {
-#if MSHADOW_CUDA_HALF2
-  return half2_t(__floats2half2_rn(__low2float(a.half2_) * 
__low2float(b.half2_),
-                                   __high2float(a.half2_) * 
__high2float(b.half2_)));
-#else
-  return half2_t(a.half_t2[0] * b.half_t2[0], a.half_t2[1] * b.half_t2[1]);
-#endif
-}
-/*! \brief overloaded / operator for half2_t */
-MSHADOW_XINLINE half2_t operator/(half2_t a, half2_t b) {
-#if MSHADOW_CUDA_HALF2
-  return half2_t(__floats2half2_rn(__low2float(a.half2_) / 
__low2float(b.half2_),
-                                   __high2float(a.half2_) / 
__high2float(b.half2_)));
-#else
-  return half2_t(a.half_t2[0] / b.half_t2[0], a.half_t2[1] / b.half_t2[1]);
-#endif
-}
-/*! \brief overloaded % operator for half2_t */
-MSHADOW_XINLINE half2_t operator%(half2_t a, half2_t b) {
-#if MSHADOW_CUDA_HALF2
-  return half2_t(__floats2half2_rn(::fmod(__low2float(a.half2_), 
__low2float(b.half2_)),
-                                   ::fmod(__high2float(a.half2_), 
__high2float(b.half2_))));
-#else
-  return half2_t(::fmod(a.half_t2[0], b.half_t2[0]), ::fmod(a.half_t2[1], 
b.half_t2[1]));
-#endif
-}
-/*! \brief overloaded == operator for half2_t */
-MSHADOW_XINLINE bool operator==(half2_t a, half2_t b) {
-#if MSHADOW_CUDA_HALF2
-  return __hbeq2(a.half2_, b.half2_);
-#else
-  return (a.half_t2[0] == b.half_t2[0] && a.half_t2[1] == b.half_t2[1]);
-#endif
-}
-
-}  // namespace half
-}  // namespace mshadow
-#endif  // MSHADOW_HALF2_H_
diff --git a/src/common/cuda_vectorization.cuh 
b/src/common/cuda_vectorization.cuh
new file mode 100644
index 0000000..7803afb
--- /dev/null
+++ b/src/common/cuda_vectorization.cuh
@@ -0,0 +1,283 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file cuda_vectorization.cuh
+ * \brief GPU helpers for vectorized memory accesses
+ */
+
+#ifndef MXNET_COMMON_CUDA_VECTORIZATION_CUH_
+#define MXNET_COMMON_CUDA_VECTORIZATION_CUH_
+
+#if MXNET_USE_CUDA && __CUDACC__
+
+#include <cuda_runtime.h>
+#include "cuda_utils.h"
+
+
+namespace mxnet {
+namespace common {
+namespace cuda {
+
+/* \brief Helper class that enables storing multiple values of type DType
+          as 1 value of type LType.
+*/
+template <typename DType, typename LType>
+class VectorizedStorage {
+ public:
+  constexpr static int nvec = sizeof(LType) / sizeof(DType);
+  union vectorized_storage {
+    LType aligned;
+    DType separate[nvec];  // NOLINT(*)
+
+    MSHADOW_XINLINE vectorized_storage() {}
+    MSHADOW_XINLINE ~vectorized_storage() {}
+  } scratch_;
+};
+
+/* \brief Helper class that enables accessing multiple values of type DType
+          as 1 value of type LType. Additional aligned template argument
+          allows performance optimizations if the pointer and the size of
+          the allocation are aligned to sizeof(LType) / sizeof(DType) elements.
+*/
+template <typename DType, typename LType, bool aligned = false>
+class VectorizedAccessor {
+ public:
+  using StorageType = VectorizedStorage<typename 
std::remove_const<DType>::type,
+                                        typename 
std::remove_const<LType>::type>;
+  StorageType storage_;
+
+  LType* aligned_ptr_;
+  DType* unaligned_ptr_;
+  int alignment_;
+  index_t n_elems_;
+
+  MSHADOW_XINLINE VectorizedAccessor(DType* ptr, const index_t size) {
+    unaligned_ptr_ = ptr;
+    if (aligned) {
+      alignment_ = 0;
+      aligned_ptr_ = reinterpret_cast<LType*>(ptr);
+      n_elems_ = (size + storage_.nvec - 1) / storage_.nvec;
+    } else {
+      size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
+      alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
+      aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_);
+      n_elems_ = (size + alignment_ + storage_.nvec - 1) / storage_.nvec;
+    }
+  }
+
+  /* \brief Alignment of the input pointer in elements. */
+  MSHADOW_XINLINE int alignment() const {
+    return alignment_;
+  }
+
+  /* \brief Access to separate elements. */
+  MSHADOW_XINLINE DType* separate() {
+    return storage_.scratch_.separate;
+  }
+
+  /* \brief Number of elements stored. */
+  MSHADOW_XINLINE constexpr int nvec() const {
+    return storage_.nvec;
+  }
+
+  /* \brief Number of aligned elements that span the entire input tensor. */
+  MSHADOW_XINLINE index_t num_aligned_elements() const {
+    return n_elems_;
+  }
+
+  /* \brief Load values from the input.
+     \param id Aligned index of the element.
+     \param N size of the tensor.
+  */
+  MSHADOW_XINLINE void load(const index_t id, const index_t N) {
+    if (aligned) {
+      storage_.scratch_.aligned = aligned_ptr_[id];
+    } else {
+      if (id > 0 && id < n_elems_ - 1) {
+        storage_.scratch_.aligned = aligned_ptr_[id];
+      } else {
+#pragma unroll
+        for (int j = 0; j < storage_.nvec; ++j) {
+          DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j;
+          if (reinterpret_cast<size_t>(ptr) >= 
reinterpret_cast<size_t>(unaligned_ptr_) &&
+              reinterpret_cast<size_t>(ptr) < 
reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
+            storage_.scratch_.separate[j] = *ptr;
+          }
+        }
+      }
+    }
+  }
+};
+
+/* \brief Class used for vectorized read-only access. */
+template <typename DType, typename LType, bool aligned = false>
+class VectorizedLoader : public VectorizedAccessor<const DType, const LType, 
aligned> {
+ public:
+  MSHADOW_XINLINE VectorizedLoader(const DType* ptr, const index_t N) :
+    VectorizedAccessor<const DType, const LType, aligned>(ptr, N) {
+  }
+};
+
+/* \brief Class used for vectorized writable access. */
+template <typename DType, typename LType, bool aligned = false>
+class VectorizedStorer : public VectorizedAccessor<DType, LType, aligned> {
+ public:
+  MSHADOW_XINLINE VectorizedStorer(DType* ptr, const index_t N) :
+    VectorizedAccessor<DType, LType, aligned>(ptr, N) {
+  }
+
+  /* \brief Store values to the output.
+     \param id Aligned index of the element.
+     \param N size of the tensor.
+  */
+  MSHADOW_XINLINE void store(const index_t id, const index_t N) {
+    if (aligned) {
+      this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
+    } else {
+      if (id > 0 && id < this->n_elems_ - 1) {
+        this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
+      } else {
+#pragma unroll
+        for (int j = 0; j < this->storage_.nvec; ++j) {
+          DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j;
+          if (reinterpret_cast<size_t>(ptr) >= 
reinterpret_cast<size_t>(this->unaligned_ptr_) &&
+              reinterpret_cast<size_t>(ptr) < 
reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
+            *ptr = this->storage_.scratch_.separate[j];
+          }
+        }
+      }
+    }
+  }
+};
+
+namespace {
+
+enum class Alignment {
+  SAME_ALIGNED,  // All tensors aligned
+  SAME_UNALIGNED,  // All tensors have the same misalignment
+  DIFFERENT  // Tensors have different alignment
+};
+
+template <typename LType, typename DType>
+int CalcAlignment(const DType* ptr) {
+  size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
+  return ptr_as_number % sizeof(LType);
+}
+
+/* \brief Check alignment of the inputs and outputs when cast to LType*.
+   \param params Structure containing arrays with inputs' and outputs' pointers
+   \param lead_dim Leading dimension of the tensors.
+   \param other_dim The size of the other dimensions of the tensors.
+*/
+template <typename LType, typename DType, typename Params>
+Alignment CheckAlignment(const Params& params, const index_t lead_dim, const 
index_t other_dim) {
+  int align = -1;
+  constexpr int nvec = sizeof(LType) / sizeof(DType);
+
+  for (const DType* ptr : params.inputs) {
+    int new_align = CalcAlignment<LType>(ptr);
+    if (align == -1) {
+      align = new_align;
+    } else {
+      if (align != new_align) {
+        return Alignment::DIFFERENT;
+      }
+    }
+  }
+
+  for (const DType* ptr : params.outputs) {
+    int new_align = CalcAlignment<LType>(ptr);
+    if (align == -1) {
+      align = new_align;
+    } else {
+      if (align != new_align) {
+        return Alignment::DIFFERENT;
+      }
+    }
+  }
+
+  if ((other_dim != 1) &&
+      (lead_dim % nvec != 0)) {
+    return Alignment::DIFFERENT;
+  }
+
+  if ((align == 0) &&
+      (lead_dim % nvec == 0)) {
+    return Alignment::SAME_ALIGNED;
+  } else {
+    return Alignment::SAME_UNALIGNED;
+  }
+}
+
+constexpr int vectorized_kernel_thread_num = 512;
+
+}  // namespace
+
+/* \brief Helper launcher function for the vectorized kernels. Checks for alignment of the
+          input and output tensors and launches a proper template.
+   \param lead_dim Leading dimension of the tensors.
+   \param other_dim The size of the other dimensions.
+   \param s Stream which should be used for launching the kernel.
+   \param params Input parameters to the kernel. Needs to contain at least 2 arrays of DType*:
+                 inputs and outputs, which contain input and output pointers.
+*/
+template <typename DType, typename LType, typename Kernel>
+void VectorizedKernelLauncher(const index_t lead_dim,
+                              const index_t other_dim,
+                              mshadow::Stream<gpu>* s,
+                              typename Kernel::ParamType params) {
+  static_assert(sizeof(LType) >= sizeof(DType), "Load type is smaller than 
operand type");
+  if (lead_dim * other_dim != 0) {
+    cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
+    VectorizedLoader<DType, LType> l(params.inputs[0], lead_dim);
+    size_t num_elements = other_dim * l.num_aligned_elements();
+    constexpr int threads = vectorized_kernel_thread_num;
+    constexpr int max_blocks = 65535;
+    index_t blocks = std::min(static_cast<int>((num_elements + threads - 1) / 
threads),
+                              max_blocks);
+    auto align = CheckAlignment<LType, DType>(params, lead_dim, other_dim);
+    switch (align) {
+      case Alignment::SAME_ALIGNED:
+        Kernel::template Launch<true, LType>(blocks, threads, stream, params, 
lead_dim, other_dim);
+        break;
+      case Alignment::SAME_UNALIGNED:
+        Kernel::template Launch<false, LType>(blocks, threads, stream, params, 
lead_dim, other_dim);
+        break;
+      case Alignment::DIFFERENT: {
+        const index_t size = lead_dim * other_dim;
+        index_t blocks = std::min(static_cast<int>((size + threads - 1) /
+                                                   threads),
+                                  max_blocks);
+        // If the pointers are aligned differently we cannot vectorize
+        Kernel::template Launch<true, DType>(blocks, threads, stream, params, 
lead_dim, other_dim);
+        break;
+      }
+    }
+  }
+}
+
+}  // namespace cuda
+}  // namespace common
+}  // namespace mxnet
+
+#endif  // MXNET_USE_CUDA && __CUDACC__
+
+#endif  // MXNET_COMMON_CUDA_VECTORIZATION_CUH_
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 2d2a0de..e0bbb4e 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -730,22 +730,8 @@ MXNET_BINARY_MATH_OP(rminus, b - a);
 
 MXNET_BINARY_MATH_OP(div_grad, 1.0f / math::id(b));
 
-template<>
-MSHADOW_XINLINE mshadow::half::half2_t div_grad::Map<mshadow::half::half2_t>
-                                               (mshadow::half::half2_t a,
-                                                mshadow::half::half2_t b) {
-  return mshadow::half::half2_t(1) / b;
-}
-
 MXNET_BINARY_MATH_OP(div_rgrad, -math::id(a) / math::sqr(b));
 
-template<>
-MSHADOW_XINLINE mshadow::half::half2_t div_rgrad::Map<mshadow::half::half2_t>
-                                               (mshadow::half::half2_t a,
-                                                mshadow::half::half2_t b) {
-  return -a / (b * b);
-}
-
 MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));
 
 MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));
@@ -795,12 +781,6 @@ struct mod : public mxnet_op::tunable {
   }
 };
 
-template<>
-MSHADOW_XINLINE mshadow::half::half2_t mod::Map<mshadow::half::half2_t>
-                                               (mshadow::half::half2_t a,
-                                                mshadow::half::half2_t b) {
-  return a%b;
-}
 
 struct mod_grad : public mxnet_op::tunable  {
   template<typename DType>
@@ -823,19 +803,6 @@ MSHADOW_XINLINE mshadow::half::half_t 
mod_grad::Map<mshadow::half::half_t>
                                                     mshadow::half::half_t b) {
   return mshadow::half::half_t(1.0f);
 }
-template<>
-MSHADOW_XINLINE mshadow::half::half2_t mod_grad::Map<mshadow::half::half2_t>
-                                                    (mshadow::half::half2_t a,
-                                                     mshadow::half::half2_t b) 
{
-  mshadow::half::half2_t result = mshadow::half::half2_t();
-#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
-  result.half2_ = ::__float2half2_rn(1.0f);
-#else
-  result.half_t2[0] = mshadow::half::half_t(0.0f);
-  result.half_t2[1] = mshadow::half::half_t(1.0f);
-#endif
-  return result;
-}
 
 struct mod_rgrad : public mxnet_op::tunable {
   template<typename DType>
@@ -858,19 +825,6 @@ MSHADOW_XINLINE mshadow::half::half_t 
mod_rgrad::Map<mshadow::half::half_t>
                                                      mshadow::half::half_t b) {
   return mshadow::half::half_t(-::floorf(static_cast<float>(a/b)));
 }
-template<>
-MSHADOW_XINLINE mshadow::half::half2_t mod_rgrad::Map<mshadow::half::half2_t>
-                                                     (mshadow::half::half2_t a,
-                                                      mshadow::half::half2_t 
b) {
-#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
-  return mshadow::half::half2_t(__hneg2(::h2floor((a/b).half2_)));
-#else
-  return mshadow::half::half2_t(mshadow::half::half_t(-::floorf(
-                                  
static_cast<float>(a.half_t2[0]/b.half_t2[0]))),
-                                mshadow::half::half_t(-::floorf(
-                                  
static_cast<float>(a.half_t2[1]/b.half_t2[1]))));
-#endif
-}
 
 struct rmod : public mxnet_op::tunable {
   template<typename DType>
@@ -907,13 +861,6 @@ struct rmod : public mxnet_op::tunable {
   }
 };
 
-template<>
-MSHADOW_XINLINE mshadow::half::half2_t rmod::Map<mshadow::half::half2_t>
-                                                (mshadow::half::half2_t a,
-                                                 mshadow::half::half2_t b) {
-  return b%a;
-}
-
 struct rmod_grad {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
@@ -935,19 +882,6 @@ MSHADOW_XINLINE mshadow::half::half_t 
rmod_grad::Map<mshadow::half::half_t>
                                                     mshadow::half::half_t b) {
   return mshadow::half::half_t(-::floorf(static_cast<float>(b/a)));
 }
-template<>
-MSHADOW_XINLINE mshadow::half::half2_t rmod_grad::Map<mshadow::half::half2_t>
-                                                     (mshadow::half::half2_t a,
-                                                      mshadow::half::half2_t 
b) {
-#if (defined(__CUDACC__) && MSHADOW_CUDA_HALF2)
-  return mshadow::half::half2_t(::__hneg2(::h2floor((b/a).half2_)));
-#else
-  return mshadow::half::half2_t(mshadow::half::half_t(-::floorf(
-                                  
static_cast<float>(b.half_t2[0]/a.half_t2[0]))),
-                                mshadow::half::half_t(-::floorf(
-                                  
static_cast<float>(b.half_t2[1]/a.half_t2[1]))));
-#endif
-}
 
 struct clip : public mxnet_op::tunable {
   template<typename DType>
diff --git a/src/operator/tensor/elemwise_binary_op.cuh 
b/src/operator/tensor/elemwise_binary_op.cuh
new file mode 100644
index 0000000..0bb9fa6
--- /dev/null
+++ b/src/operator/tensor/elemwise_binary_op.cuh
@@ -0,0 +1,322 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file elemwise_binary_op.cuh
+ * \brief GPU helpers for elementwise operators
+ */
+
+#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_CUH_
+#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_CUH_
+
+#include <cuda_runtime.h>
+#include "../operator_common.h"
+#include "../../common/cuda_vectorization.cuh"
+
+#include <vector>
+
+#if MXNET_USE_CUDA
+
+namespace mxnet {
+namespace op {
+
+namespace binary {
+
+using common::cuda::VectorizedKernelLauncher;
+using common::cuda::VectorizedLoader;
+using common::cuda::VectorizedStorer;
+
+/*!
+ * \brief Bundle of raw tensor pointers handed by value to the vectorized
+ *        binary kernels in this file.
+ * \tparam DType element type of every input and output
+ * \tparam NumInputs number of read-only input pointers carried
+ * \tparam NumOutputs number of writable output pointers carried
+ */
+template <typename DType, int NumInputs, int NumOutputs>
+struct VectorizedBinaryKernelParams {
+  // Read-only source tensors, indexed by operand position.
+  const DType* inputs[NumInputs];
+  // Destination tensors, indexed by result position.
+  DType* outputs[NumOutputs];
+};
+
+/*!
+ * \brief Forward kernel: applies OP::Map to matching elements of two inputs
+ *        and writes (or accumulates) the result into one output.
+ * \tparam aligned forwarded to the loader/storer helpers
+ * \tparam DType element type
+ * \tparam LType type used by the loader/storer for the wide accesses
+ * \tparam OP binary functor providing a static Map(a, b)
+ * \tparam req write request; kAddTo accumulates, anything else overwrites
+ * \param params input/output pointers
+ * \param N size forwarded to the loader/storer helpers
+ */
+template <bool aligned, typename DType, typename LType, typename OP, int req>
+__global__ void VectorizedBinaryKernelFwd(const VectorizedBinaryKernelParams<DType, 2, 1> params,
+                                          const index_t N) {
+  VectorizedLoader<DType, LType, aligned> lhs(params.inputs[0], N);
+  VectorizedLoader<DType, LType, aligned> rhs(params.inputs[1], N);
+  VectorizedStorer<DType, LType, aligned> out(params.outputs[0], N);
+
+  const index_t num_vectors = lhs.num_aligned_elements();
+  // Total number of threads in the grid; each thread strides by this amount.
+  const auto stride = gridDim.x * blockDim.x;
+  constexpr bool accumulate = (req == kAddTo);
+
+  index_t vec = blockIdx.x * blockDim.x + threadIdx.x;
+  while (vec < num_vectors) {
+    lhs.load(vec, N);
+    rhs.load(vec, N);
+    if (accumulate) {
+      // Accumulation needs the current contents of the destination.
+      out.load(vec, N);
+    }
+#pragma unroll
+    for (int lane = 0; lane < lhs.nvec(); ++lane) {
+      const DType value = OP::Map(lhs.separate()[lane], rhs.separate()[lane]);
+      if (accumulate) {
+        out.separate()[lane] += value;
+      } else {
+        out.separate()[lane] = value;
+      }
+    }
+    out.store(vec, N);
+    vec += stride;
+  }
+}
+
+/*!
+ * \brief Backward kernel for binary ops whose gradients depend only on the
+ *        incoming gradient: lgrad = LOP::Map(ograd), rgrad = ROP::Map(ograd).
+ * \tparam aligned forwarded to the loader/storer helpers
+ * \tparam DType element type
+ * \tparam LType type used by the loader/storer for the wide accesses
+ * \tparam LOP unary functor producing the left gradient
+ * \tparam ROP unary functor producing the right gradient
+ * \tparam lreq write request for the left gradient
+ * \tparam rreq write request for the right gradient
+ * \param params params.inputs[0] is the incoming gradient,
+ *               params.outputs[0] / [1] are the left / right gradients
+ * \param N size forwarded to the loader/storer helpers
+ */
+template <bool aligned, typename DType, typename LType,
+          typename LOP, typename ROP, int lreq, int rreq>
+__global__ void VectorizedBinaryKernelBwdUseNone(
+    const VectorizedBinaryKernelParams<DType, 1, 2> params,
+    const index_t N) {
+  VectorizedLoader<DType, LType, aligned> loader(params.inputs[0], N);
+  VectorizedStorer<DType, LType, aligned> lstorer(params.outputs[0], N);
+  VectorizedStorer<DType, LType, aligned> rstorer(params.outputs[1], N);
+
+  // An output is skipped when nothing is requested for it, or when an
+  // identity op is asked to write in place.
+  constexpr bool write_lgrad =
+    !((std::is_same<LOP, mshadow_op::identity>::value && lreq == kWriteInplace) ||
+      lreq == kNullOp);
+  constexpr bool write_rgrad =
+    !((std::is_same<ROP, mshadow_op::identity>::value && rreq == kWriteInplace) ||
+      rreq == kNullOp);
+
+  const index_t M = loader.num_aligned_elements();
+
+  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    loader.load(tid, N);
+    if (lreq == kAddTo) {
+      lstorer.load(tid, N);
+    }
+    if (rreq == kAddTo) {
+      rstorer.load(tid, N);
+    }
+#pragma unroll
+    for (int i = 0; i < loader.nvec(); ++i) {
+      DType inp = loader.separate()[i];
+      if (write_lgrad) {
+        DType ltemp = LOP::Map(inp);
+        if (lreq == kAddTo) {
+          lstorer.separate()[i] += ltemp;
+        } else {
+          lstorer.separate()[i] = ltemp;
+        }
+      }
+      if (write_rgrad) {
+        DType rtemp = ROP::Map(inp);
+
+        if (rreq == kAddTo) {
+          rstorer.separate()[i] += rtemp;
+        } else {
+          rstorer.separate()[i] = rtemp;
+        }
+      }
+    }
+    // Write each result vector once, after every lane has been computed,
+    // instead of issuing a store per lane.
+    if (write_lgrad) {
+      lstorer.store(tid, N);
+    }
+    if (write_rgrad) {
+      rstorer.store(tid, N);
+    }
+  }
+}
+
+/*!
+ * \brief Backward kernel for binary ops whose gradients need the forward
+ *        inputs: lgrad = ograd * LOP::Map(lhs, rhs),
+ *                rgrad = ograd * ROP::Map(lhs, rhs).
+ * \tparam aligned forwarded to the loader/storer helpers
+ * \tparam DType element type
+ * \tparam LType type used by the loader/storer for the wide accesses
+ * \tparam LOP binary functor for the left gradient factor
+ * \tparam ROP binary functor for the right gradient factor
+ * \tparam lreq write request for the left gradient
+ * \tparam rreq write request for the right gradient
+ * \param params params.inputs[0] is the incoming gradient, inputs[1] / [2]
+ *               the left / right forward inputs, outputs[0] / [1] the
+ *               left / right gradients
+ * \param N size forwarded to the loader/storer helpers
+ */
+template <bool aligned, typename DType, typename LType,
+          typename LOP, typename ROP, int lreq, int rreq>
+__global__ void VectorizedBinaryKernelBwdUseIn(
+    const VectorizedBinaryKernelParams<DType, 3, 2> params,
+    const index_t N) {
+  VectorizedLoader<DType, LType, aligned> ograd_loader(params.inputs[0], N);
+  VectorizedLoader<DType, LType, aligned> linput_loader(params.inputs[1], N);
+  VectorizedLoader<DType, LType, aligned> rinput_loader(params.inputs[2], N);
+  VectorizedStorer<DType, LType, aligned> lstorer(params.outputs[0], N);
+  VectorizedStorer<DType, LType, aligned> rstorer(params.outputs[1], N);
+
+  const index_t M = ograd_loader.num_aligned_elements();
+
+  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    ograd_loader.load(tid, N);
+    linput_loader.load(tid, N);
+    rinput_loader.load(tid, N);
+    if (lreq == kAddTo) {
+      lstorer.load(tid, N);
+    }
+    if (rreq == kAddTo) {
+      rstorer.load(tid, N);
+    }
+#pragma unroll
+    for (int i = 0; i < ograd_loader.nvec(); ++i) {
+      DType ograd = ograd_loader.separate()[i];
+      DType linput = linput_loader.separate()[i];
+      DType rinput = rinput_loader.separate()[i];
+      if (lreq != kNullOp) {
+        DType ltemp = ograd * LOP::Map(linput, rinput);
+        if (lreq == kAddTo) {
+          lstorer.separate()[i] += ltemp;
+        } else {
+          lstorer.separate()[i] = ltemp;
+        }
+      }
+      if (rreq != kNullOp) {
+        DType rtemp = ograd * ROP::Map(linput, rinput);
+
+        if (rreq == kAddTo) {
+          rstorer.separate()[i] += rtemp;
+        } else {
+          rstorer.separate()[i] = rtemp;
+        }
+      }
+    }
+    // Write each result vector once, after every lane has been computed,
+    // instead of issuing a store per lane.
+    if (lreq != kNullOp) {
+      lstorer.store(tid, N);
+    }
+    if (rreq != kNullOp) {
+      rstorer.store(tid, N);
+    }
+  }
+}
+
+/*!
+ * \brief Launch adapter that binds VectorizedBinaryKernelFwd to a fixed
+ *        DType / OP / req and exposes the static Launch entry point.
+ */
+template <typename DType, typename OP, int req>
+class VectorizedBinaryFwd {
+ public:
+  using ParamType = VectorizedBinaryKernelParams<DType, 2, 1>;
+
+  /*!
+   * \brief Enqueue the forward kernel on the given stream.
+   *        The trailing (other_dim) argument is accepted but not used.
+   */
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t lead_dim,
+                     const index_t /* other_dim */) {
+    const auto kernel = VectorizedBinaryKernelFwd<aligned, DType, LType, OP, req>;
+    kernel<<<blocks, threads, 0, stream>>>(params, lead_dim);
+  }
+};
+
+/*!
+ * \brief Launch adapter that binds VectorizedBinaryKernelBwdUseNone to a
+ *        fixed DType / LOP / ROP / lreq / rreq.
+ */
+template <typename DType, typename LOP, typename ROP, int lreq, int rreq>
+class VectorizedBinaryBwdUseNone {
+ public:
+  using ParamType = VectorizedBinaryKernelParams<DType, 1, 2>;
+
+  /*!
+   * \brief Enqueue the backward (use-none) kernel on the given stream.
+   *        The trailing (other_dim) argument is accepted but not used.
+   */
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t lead_dim,
+                     const index_t /* other_dim */) {
+    const auto kernel =
+      VectorizedBinaryKernelBwdUseNone<aligned, DType, LType, LOP, ROP, lreq, rreq>;
+    kernel<<<blocks, threads, 0, stream>>>(params, lead_dim);
+  }
+};
+
+/*!
+ * \brief Launch adapter that binds VectorizedBinaryKernelBwdUseIn to a
+ *        fixed DType / LOP / ROP / lreq / rreq.
+ */
+template <typename DType, typename LOP, typename ROP, int lreq, int rreq>
+class VectorizedBinaryBwdUseIn {
+ public:
+  using ParamType = VectorizedBinaryKernelParams<DType, 3, 2>;
+
+  /*!
+   * \brief Enqueue the backward (use-in) kernel on the given stream.
+   *        The trailing (other_dim) argument is accepted but not used.
+   */
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t lead_dim,
+                     const index_t /* other_dim */) {
+    const auto kernel =
+      VectorizedBinaryKernelBwdUseIn<aligned, DType, LType, LOP, ROP, lreq, rreq>;
+    kernel<<<blocks, threads, 0, stream>>>(params, lead_dim);
+  }
+};
+
+}  // namespace binary
+
+/*!
+ * \brief GPU forward pass of an elementwise binary operator.
+ *        Dispatches on the write request and output dtype, then hands the
+ *        two inputs and one output to the vectorized forward kernel.
+ * \param attrs node attributes (not read here)
+ * \param s GPU stream passed on to the kernel launcher
+ * \param inputs exactly two input blobs
+ * \param req write request; only req[0] is consulted
+ * \param outputs exactly one output blob
+ */
+template<typename OP>
+void ElemwiseBinaryOp::Compute_(const nnvm::NodeAttrs &attrs,
+                                mshadow::Stream<gpu> *s,
+                                const std::vector<TBlob> &inputs,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<TBlob> &outputs) {
+  if (req[0] == kNullOp) return;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      using LType = uint4;
+      using Kernel = binary::VectorizedBinaryFwd<DType, OP, Req>;
+
+      // The launch covers as many elements as the output holds.
+      const index_t num_elements = outputs[0].Size();
+
+      typename Kernel::ParamType kernel_params;
+      kernel_params.inputs[0] = inputs[0].dptr<DType>();
+      kernel_params.inputs[1] = inputs[1].dptr<DType>();
+      kernel_params.outputs[0] = outputs[0].dptr<DType>();
+
+      binary::VectorizedKernelLauncher<DType, LType, Kernel>(num_elements, 1, s, kernel_params);
+    });
+  });
+}
+
+/*!
+ * \brief GPU backward pass for binary operators whose gradients depend only
+ *        on the incoming gradient.
+ *        Dispatches on dtype and on both write requests, then launches the
+ *        vectorized use-none backward kernel unless both requests are kNullOp.
+ * \param attrs node attributes (not read here)
+ * \param s GPU stream passed on to the kernel launcher
+ * \param inputs inputs[0] is the incoming gradient
+ * \param req req[0] / req[1] are the write requests for the two gradients
+ * \param outputs outputs[0] / outputs[1] receive the two gradients
+ */
+template<typename LOP, typename ROP>
+void ElemwiseBinaryOp::BackwardUseNone_(const nnvm::NodeAttrs &attrs,
+                                        mshadow::Stream<gpu>* s,
+                                        const std::vector<TBlob> &inputs,
+                                        const std::vector<OpReqType> &req,
+                                        const std::vector<TBlob> &outputs) {
+  using namespace binary;
+
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    const index_t size = inputs[0].Size();
+    if (req[0] != kNullOp || req[1] != kNullOp) {
+      MXNET_REQ_TYPE_SWITCH(req[0], lreq, {
+        MXNET_REQ_TYPE_SWITCH(req[1], rreq, {
+          using LType = uint4;
+          using Kernel = VectorizedBinaryBwdUseNone<DType, LOP, ROP, lreq, rreq>;
+
+          typename Kernel::ParamType params;
+          params.inputs[0] = inputs[0].dptr<DType>();
+          params.outputs[0] = outputs[0].dptr<DType>();
+          params.outputs[1] = outputs[1].dptr<DType>();
+
+          // The launcher receives the mshadow stream itself, so no raw
+          // cudaStream_t is needed at this level.
+          VectorizedKernelLauncher<DType, LType, Kernel>(size, 1, s, params);
+        });
+      });
+    }
+  });
+}
+
+/*!
+ * \brief GPU backward pass for binary operators whose gradients need the
+ *        forward inputs.
+ *        Does nothing when both write requests are kNullOp; otherwise
+ *        dispatches on dtype and both requests and launches the vectorized
+ *        use-in backward kernel.
+ * \param attrs node attributes (not read here)
+ * \param s GPU stream passed on to the kernel launcher
+ * \param inputs inputs[0] is the incoming gradient, inputs[1] / inputs[2]
+ *               the forward operands
+ * \param req req[0] / req[1] are the write requests for the two gradients
+ * \param outputs outputs[0] / outputs[1] receive the two gradients
+ */
+template<typename LOP, typename ROP>
+void ElemwiseBinaryOp::BackwardUseIn_(const nnvm::NodeAttrs &attrs,
+                                      mshadow::Stream<gpu>* s,
+                                      const std::vector<TBlob> &inputs,
+                                      const std::vector<OpReqType> &req,
+                                      const std::vector<TBlob> &outputs) {
+  // Nothing to compute when neither gradient is requested.
+  if (req[0] == kNullOp && req[1] == kNullOp) return;
+
+  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
+    MXNET_REQ_TYPE_SWITCH(req[0], lreq, {
+      MXNET_REQ_TYPE_SWITCH(req[1], rreq, {
+        // Using 64 bit loads to reduce register pressure
+        using LType = uint2;
+        using Kernel = binary::VectorizedBinaryBwdUseIn<DType, LOP, ROP, lreq, rreq>;
+
+        const index_t num_elements = inputs[0].Size();
+
+        typename Kernel::ParamType kernel_params;
+        kernel_params.inputs[0] = inputs[0].dptr<DType>();
+        kernel_params.inputs[1] = inputs[1].dptr<DType>();
+        kernel_params.inputs[2] = inputs[2].dptr<DType>();
+        kernel_params.outputs[0] = outputs[0].dptr<DType>();
+        kernel_params.outputs[1] = outputs[1].dptr<DType>();
+
+        binary::VectorizedKernelLauncher<DType, LType, Kernel>(num_elements, 1, s, kernel_params);
+      });
+    });
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_CUDA
+#endif  // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_CUH_
diff --git a/src/operator/tensor/elemwise_binary_op.h 
b/src/operator/tensor/elemwise_binary_op.h
index bc5140a..b9396ae 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -106,62 +106,85 @@ class ElemwiseBinaryOp : public OpBase {
   }
 
  private:
-  template<typename xpu, typename LOP, typename ROP, typename DType>
+  template<typename LOP, typename ROP>
   static void BackwardUseNone_(const nnvm::NodeAttrs &attrs,
-                               const OpContext &ctx,
+                               mshadow::Stream<cpu>* s,
                                const std::vector<TBlob> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<TBlob> &outputs) {
-    using namespace mxnet_op;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    const int size = static_cast<int>((outputs[0].Size() + 
DataType<DType>::kLanes - 1)
-                                      / DataType<DType>::kLanes);
-    const DType *ograd_dptr = inputs[0].dptr<DType>();
-    if (std::is_same<LOP, mshadow_op::identity>::value && req[0] == 
kWriteInplace) {
-      CHECK_EQ(ograd_dptr, outputs[0].dptr<DType>());
-    } else if (req[0] != kNullOp) {
-      DType *lgrad_dptr = outputs[0].dptr<DType>();
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      using namespace mxnet_op;
+      const int size = static_cast<int>((outputs[0].Size() + 
DataType<DType>::kLanes - 1)
+                                        / DataType<DType>::kLanes);
+      const DType *ograd_dptr = inputs[0].dptr<DType>();
+      if (std::is_same<LOP, mshadow_op::identity>::value && req[0] == 
kWriteInplace) {
+        CHECK_EQ(ograd_dptr, outputs[0].dptr<DType>());
+      } else if (req[0] != kNullOp) {
+        DType *lgrad_dptr = outputs[0].dptr<DType>();
+        MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+          Kernel<mxnet_op::op_with_req<LOP, Req>, cpu>::Launch(s, size, 
lgrad_dptr, ograd_dptr);
+        });
+      }
+      if (std::is_same<ROP, mshadow_op::identity>::value && req[1] == 
kWriteInplace) {
+        CHECK_EQ(ograd_dptr, outputs[1].dptr<DType>());
+      } else if (req[1] != kNullOp) {
+        DType *rgrad_dptr = outputs[1].dptr<DType>();
+        MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
+          Kernel<mxnet_op::op_with_req<ROP, Req>, cpu>::Launch(s, size, 
rgrad_dptr, ograd_dptr);
+        });
+      }
+    });
+  }
+#if MXNET_USE_CUDA
+  // GPU overload of BackwardUseNone_, selected by the mshadow::Stream<gpu>*
+  // argument. Only declared here; the definition lives in
+  // elemwise_binary_op.cuh, which this header includes under __CUDACC__.
+  template<typename LOP, typename ROP>
+  static void BackwardUseNone_(const nnvm::NodeAttrs &attrs,
+                               mshadow::Stream<gpu>* s,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs);
+#endif
+
+  template<typename LOP, typename ROP>
+  static void BackwardUseIn_(const nnvm::NodeAttrs &attrs,
+                             mshadow::Stream<cpu>* s,
+                             const std::vector<TBlob> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<TBlob> &outputs) {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      DCHECK_EQ(outputs.size(), 2U);
+      DCHECK_EQ(inputs.size(), 3U);
+      const DType *ograd_dptr = inputs[0].dptr<DType>();
+      const DType *lhs_dptr = inputs[1].dptr<DType>();
+      const DType *rhs_dptr = inputs[2].dptr<DType>();
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        Kernel<mxnet_op::op_with_req<LOP, Req>, xpu>::Launch(s, size, 
lgrad_dptr, ograd_dptr);
+        const int size = static_cast<int>(
+          (outputs[0].Size() + mxnet_op::DataType<DType>::kLanes - 1)
+          / mxnet_op::DataType<DType>::kLanes);
+        DType * lgrad_dptr = outputs[0].dptr<DType>();
+        mxnet_op::Kernel<
+          mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<LOP>, Req>, 
cpu>::Launch(
+            s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);
       });
-    }
-    if (std::is_same<ROP, mshadow_op::identity>::value && req[1] == 
kWriteInplace) {
-      CHECK_EQ(ograd_dptr, outputs[1].dptr<DType>());
-    } else if (req[1] != kNullOp) {
-      DType *rgrad_dptr = outputs[1].dptr<DType>();
       MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
-        Kernel<mxnet_op::op_with_req<ROP, Req>, xpu>::Launch(s, size, 
rgrad_dptr, ograd_dptr);
+        const int size = static_cast<int>(
+          (outputs[1].Size() + mxnet_op::DataType<DType>::kLanes - 1)
+          / mxnet_op::DataType<DType>::kLanes);
+        DType * rgrad_dptr = outputs[1].dptr<DType>();
+        mxnet_op::Kernel<
+          mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<ROP>, Req>, 
cpu>::Launch(
+            s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);
       });
-    }
+    });
   }
 
-  template<typename xpu, typename LOP, typename ROP, typename DType>
+#if MXNET_USE_CUDA
+  template<typename LOP, typename ROP>
   static void BackwardUseIn_(const nnvm::NodeAttrs &attrs,
-                             const OpContext &ctx,
+                             mshadow::Stream<gpu>* s,
                              const std::vector<TBlob> &inputs,
                              const std::vector<OpReqType> &req,
-                             const std::vector<TBlob> &outputs) {
-    DCHECK_EQ(outputs.size(), 2U);
-    DCHECK_EQ(inputs.size(), 3U);
-    mxnet_op::Stream<xpu> *s = ctx.get_stream<xpu>();
-    const DType *ograd_dptr = inputs[0].dptr<DType>();
-    const DType *lhs_dptr = inputs[1].dptr<DType>();
-    const DType *rhs_dptr = inputs[2].dptr<DType>();
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      const int size = static_cast<int>(
-        (outputs[0].Size() + mxnet_op::DataType<DType>::kLanes - 1)
-        / mxnet_op::DataType<DType>::kLanes);
-      DType * lgrad_dptr = outputs[0].dptr<DType>();
-      
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<LOP>, 
Req>, xpu>::Launch(
-        s, size, lgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
-    MXNET_ASSIGN_REQ_SWITCH(req[1], Req, {
-      const int size = static_cast<int>(
-        (outputs[1].Size() + mxnet_op::DataType<DType>::kLanes - 1)
-        / mxnet_op::DataType<DType>::kLanes);
-      DType * rgrad_dptr = outputs[1].dptr<DType>();
-      
mxnet_op::Kernel<mxnet_op::op_with_req<mxnet_op::backward_grad_tuned<ROP>, 
Req>, xpu>::Launch(
-        s, size, rgrad_dptr, ograd_dptr, lhs_dptr, rhs_dptr);});
-  }
+                             const std::vector<TBlob> &outputs);
+#endif
 
   template<
     typename xpu,
@@ -498,15 +521,13 @@ class ElemwiseBinaryOp : public OpBase {
     });
   }
 
-  template<typename xpu, typename OP>
-  static void Compute(const nnvm::NodeAttrs &attrs,
-                      const OpContext &ctx,
-                      const std::vector<TBlob> &inputs,
-                      const std::vector<OpReqType> &req,
-                      const std::vector<TBlob> &outputs) {
+  template<typename OP>
+  static void Compute_(const nnvm::NodeAttrs &attrs,
+                       mshadow::Stream<cpu> *s,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
     using namespace mxnet_op;
-    if (req[0] == kNullOp) return;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
     CHECK_EQ(inputs.size(), 2U);
     CHECK_EQ(outputs.size(), 1U);
     if (outputs[0].type_flag_ == mshadow::kBool) {
@@ -517,7 +538,7 @@ class ElemwiseBinaryOp : public OpBase {
         const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
         + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
         if (size != 0) {
-          Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
+          Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(s, size,
           outputs[0].dptr<DType>(),
           inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
         }
@@ -525,6 +546,26 @@ class ElemwiseBinaryOp : public OpBase {
     });
   }
 
+#if MXNET_USE_CUDA
+  // GPU overload of Compute_, selected by the mshadow::Stream<gpu>* argument.
+  // Only declared here; the definition lives in elemwise_binary_op.cuh,
+  // which this header includes under __CUDACC__.
+  template<typename OP>
+  static void Compute_(const nnvm::NodeAttrs &attrs,
+                       mshadow::Stream<gpu> *s,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs);
+#endif
+
+  /*!
+   * \brief Device-generic forward entry point.
+   *        Skips the call when req[0] is kNullOp; otherwise fetches the
+   *        stream for xpu from the context and forwards to the Compute_
+   *        overload taking that stream type.
+   */
+  template<typename xpu, typename OP>
+  static void Compute(const nnvm::NodeAttrs &attrs,
+                      const OpContext &ctx,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<TBlob> &outputs) {
+    if (req[0] != kNullOp) {
+      Compute_<OP>(attrs, ctx.get_stream<xpu>(), inputs, req, outputs);
+    }
+  }
+
   template<typename xpu, typename OP>
   static void ComputeWithBool(const nnvm::NodeAttrs &attrs,
                               const OpContext &ctx,
@@ -575,30 +616,6 @@ class ElemwiseBinaryOp : public OpBase {
   }
 
   template<typename xpu, typename OP>
-  static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
-                               const OpContext &ctx,
-                               const std::vector<TBlob> &inputs,
-                               const std::vector<OpReqType> &req,
-                               const std::vector<TBlob> &outputs) {
-    using namespace mxnet_op;
-    if (req[0] == kNullOp) return;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    CHECK_EQ(inputs.size(), 2U);
-    CHECK_EQ(outputs.size(), 1U);
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
-        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
-        + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
-        if (size != 0) {
-          Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
-          outputs[0].dptr<DType>(),
-          inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
-        }
-      });
-    });
-  }
-
-  template<typename xpu, typename OP>
   static void ComputeEx(const nnvm::NodeAttrs &attrs,
                         const OpContext &ctx,
                         const std::vector<NDArray> &inputs,
@@ -694,20 +711,8 @@ class ElemwiseBinaryOp : public OpBase {
                                      const std::vector<TBlob> &inputs,
                                      const std::vector<OpReqType> &req,
                                      const std::vector<TBlob> &outputs) {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      BackwardUseNone_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
-    });
-  }
-
-  template<typename xpu, typename LOP, typename ROP>
-  static inline void BackwardUseNoneWithHalf2(const nnvm::NodeAttrs &attrs,
-                                              const OpContext &ctx,
-                                              const std::vector<TBlob> &inputs,
-                                              const std::vector<OpReqType> 
&req,
-                                              const std::vector<TBlob> 
&outputs) {
-    MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
-      BackwardUseNone_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
-    });
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    BackwardUseNone_<LOP, ROP>(attrs, s, inputs, req, outputs);
   }
 
   template<typename xpu, typename LOP, typename ROP>
@@ -751,20 +756,8 @@ class ElemwiseBinaryOp : public OpBase {
                                    const std::vector<TBlob> &inputs,
                                    const std::vector<OpReqType> &req,
                                    const std::vector<TBlob> &outputs) {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      BackwardUseIn_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
-    });
-  }
-
-  template<typename xpu, typename LOP, typename ROP>
-  static inline void BackwardUseInWithHalf2(const nnvm::NodeAttrs &attrs,
-                                            const OpContext &ctx,
-                                            const std::vector<TBlob> &inputs,
-                                            const std::vector<OpReqType> &req,
-                                            const std::vector<TBlob> &outputs) 
{
-    MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
-      BackwardUseIn_<xpu, LOP, ROP, DType>(attrs, ctx, inputs, req, outputs);
-    });
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    BackwardUseIn_<LOP, ROP>(attrs, s, inputs, req, outputs);
   }
 
   template<
@@ -863,4 +856,9 @@ class ElemwiseBinaryOp : public OpBase {
 
 }  // namespace op
 }  // namespace mxnet
+
+#ifdef __CUDACC__
+#include "elemwise_binary_op.cuh"
+#endif  // __CUDACC__
+
 #endif  // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_OP_H_
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu 
b/src/operator/tensor/elemwise_binary_op_basic.cu
index 16d7fc1..b21b08d 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -218,52 +218,51 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<gpu> 
*s,
 }
 
 NNVM_REGISTER_OP(elemwise_add)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, 
op::mshadow_op::plus>)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, 
op::mshadow_op::plus>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, 
op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_grad_add)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, 
op::mshadow_op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, 
op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_backward_add)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::BackwardUseNoneWithHalf2<gpu, 
mshadow_op::identity,
+                    ElemwiseBinaryOp::BackwardUseNone<gpu, 
mshadow_op::identity,
                     mshadow_op::identity>);
 
 NNVM_REGISTER_OP(elemwise_sub)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<
-  gpu, op::mshadow_op::minus>)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, 
op::mshadow_op::minus>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, 
op::mshadow_op::minus>);
 
 NNVM_REGISTER_OP(_backward_sub)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::BackwardUseNoneWithHalf2<gpu, 
mshadow_op::identity,
+                    ElemwiseBinaryOp::BackwardUseNone<gpu, 
mshadow_op::identity,
                     mshadow_op::negation>);
 
 NNVM_REGISTER_OP(elemwise_mul)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, 
op::mshadow_op::mul>)
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, 
op::mshadow_op::mul>)
 .set_attr<FComputeEx>("FComputeEx<gpu>",
   ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);
 
 NNVM_REGISTER_OP(_backward_mul)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::BackwardUseInWithHalf2<gpu, 
mshadow_op::right,
+                    ElemwiseBinaryOp::BackwardUseIn<gpu, mshadow_op::right,
                     mshadow_op::left>);
 
 NNVM_REGISTER_OP(elemwise_div)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::ElemwiseBinaryOp::ComputeWithHalf2<gpu, 
op::mshadow_op::div>);
+                    ElemwiseBinaryOp::Compute<gpu, op::mshadow_op::div>);
 
 NNVM_REGISTER_OP(_backward_div)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::BackwardUseInWithHalf2<gpu, 
mshadow_op::div_grad,
+                    ElemwiseBinaryOp::BackwardUseIn<gpu, mshadow_op::div_grad,
                     mshadow_op::div_rgrad>);
 
 NNVM_REGISTER_OP(_mod)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, 
mshadow_op::mod>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::Compute<gpu, 
mshadow_op::mod>);
 
 NNVM_REGISTER_OP(_backward_mod)
 .set_attr<FCompute>("FCompute<gpu>",
-  ElemwiseBinaryOp::BackwardUseInWithHalf2<gpu, mshadow_op::mod_grad, 
mshadow_op::mod_rgrad>);
+  ElemwiseBinaryOp::BackwardUseIn<gpu, mshadow_op::mod_grad, 
mshadow_op::mod_rgrad>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.cuh 
b/src/operator/tensor/elemwise_binary_scalar_op.cuh
new file mode 100644
index 0000000..062c187
--- /dev/null
+++ b/src/operator/tensor/elemwise_binary_scalar_op.cuh
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file elemwise_binary_scalar_op.cuh
+ * \brief GPU helpers for binary elementwise operators with scalar
+ */
+
+#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_CUH_
+#define MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_CUH_
+
+#include <cuda_runtime.h>
+#include "../operator_common.h"
+#include "../../common/cuda_vectorization.cuh"
+
+#include <vector>
+
+#if MXNET_USE_CUDA
+
+namespace mxnet {
+namespace op {
+
+namespace binary_scalar {
+
+using common::cuda::VectorizedKernelLauncher;
+using common::cuda::VectorizedLoader;
+using common::cuda::VectorizedStorer;
+
+template <typename DType, int NumInputs, int NumOutputs>
+struct VectorizedKernelParams {
+  const DType* inputs[NumInputs];
+  DType* outputs[NumOutputs];
+  DType scalar;
+};
+
+template <bool aligned, typename DType, typename LType, typename OP, int req>
+__global__ void VectorizedBinaryScalarKernelFwd(const VectorizedKernelParams<DType, 1, 1> params,
+                                                const index_t N) {
+  VectorizedLoader<DType, LType, aligned> loader0(params.inputs[0], N);
+  VectorizedStorer<DType, LType, aligned> storer(params.outputs[0], N);
+
+  const index_t M = loader0.num_aligned_elements();
+
+  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    loader0.load(tid, N);
+    if (req == kAddTo) {
+      storer.load(tid, N);
+    }
+#pragma unroll
+    for (int i = 0; i < loader0.nvec(); ++i) {
+      DType temp = OP::Map(loader0.separate()[i],
+                           params.scalar);
+
+      if (req == kAddTo) {
+        storer.separate()[i] += temp;
+      } else {
+        storer.separate()[i] = temp;
+      }
+    }
+    storer.store(tid, N);
+  }
+}
+
+template <bool aligned, typename DType, typename LType, typename OP, int req>
+__global__ void VectorizedBinaryScalarKernelBwd(const VectorizedKernelParams<DType, 2, 1> params,
+                                                const index_t N) {
+  VectorizedLoader<DType, LType, aligned> ograd_loader(params.inputs[0], N);
+  VectorizedLoader<DType, LType, aligned> input_loader(params.inputs[1], N);
+  VectorizedStorer<DType, LType, aligned> storer(params.outputs[0], N);
+
+  const index_t M = ograd_loader.num_aligned_elements();
+
+  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    ograd_loader.load(tid, N);
+    input_loader.load(tid, N);
+    if (req == kAddTo) {
+      storer.load(tid, N);
+    }
+#pragma unroll
+    for (int i = 0; i < ograd_loader.nvec(); ++i) {
+      DType ograd = ograd_loader.separate()[i];
+      DType temp = ograd * OP::Map(input_loader.separate()[i],
+                                   params.scalar);
+
+      if (req == kAddTo) {
+        storer.separate()[i] += temp;
+      } else {
+        storer.separate()[i] = temp;
+      }
+    }
+    storer.store(tid, N);
+  }
+}
+
+template <typename DType, typename OP, int req>
+class VectorizedBinaryScalarFwd {
+ public:
+  using ParamType = VectorizedKernelParams<DType, 1, 1>;
+
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t lead_dim,
+                     const index_t /* other_dim */) {
+    VectorizedBinaryScalarKernelFwd<aligned, DType, LType, OP, req>
+      <<<blocks, threads, 0, stream>>>(params, lead_dim);
+  }
+};
+
+template <typename DType, typename OP, int req>
+class VectorizedBinaryScalarBwd {
+ public:
+  using ParamType = VectorizedKernelParams<DType, 2, 1>;
+
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t lead_dim,
+                     const index_t /* other_dim */) {
+    VectorizedBinaryScalarKernelBwd<aligned, DType, LType, OP, req>
+      <<<blocks, threads, 0, stream>>>(params, lead_dim);
+  }
+};
+
+}  // namespace binary_scalar
+
+template <typename OP>
+void BinaryScalarOp::Compute_(const nnvm::NodeAttrs &attrs,
+                              mshadow::Stream<gpu>* s,
+                              const std::vector<TBlob> &inputs,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<TBlob> &outputs) {
+  using namespace binary_scalar;
+  if (req[0] == kNullOp) return;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  const double alpha = nnvm::get<double>(attrs.parsed);
+  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      using LType = uint4;
+      using Kernel = VectorizedBinaryScalarFwd<DType, OP, Req>;
+
+      const index_t size = outputs[0].Size();
+      typename Kernel::ParamType params;
+      params.inputs[0] = inputs[0].dptr<DType>();
+      params.outputs[0] = outputs[0].dptr<DType>();
+      params.scalar = (DType)alpha;
+
+      VectorizedKernelLauncher<DType, LType, Kernel>(size, 1, s, params);
+    });
+  });
+}
+
+template <typename OP>
+void BinaryScalarOp::Backward_(const nnvm::NodeAttrs &attrs,
+                               mshadow::Stream<gpu>* s,
+                               const std::vector<TBlob> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<TBlob> &outputs) {
+  using namespace binary_scalar;
+  if (req[0] == kNullOp) return;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const double alpha = nnvm::get<double>(attrs.parsed);
+  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      using LType = uint4;
+      using Kernel = VectorizedBinaryScalarBwd<DType, OP, Req>;
+
+      const index_t size = outputs[0].Size();
+      typename Kernel::ParamType params;
+      params.inputs[0] = inputs[0].dptr<DType>();
+      params.inputs[1] = inputs[1].dptr<DType>();
+      params.outputs[0] = outputs[0].dptr<DType>();
+      params.scalar = (DType)alpha;
+
+      VectorizedKernelLauncher<DType, LType, Kernel>(size, 1, s, params);
+    });
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_CUDA
+#endif  // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_CUH_
diff --git a/src/operator/tensor/elemwise_binary_scalar_op.h b/src/operator/tensor/elemwise_binary_scalar_op.h
index 3e87028..f974332 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op.h
+++ b/src/operator/tensor/elemwise_binary_scalar_op.h
@@ -224,26 +224,44 @@ class BinaryScalarOp : public UnaryOp {
   }
 
  public:
-  template<typename xpu, typename OP>
-  static void Compute(const nnvm::NodeAttrs &attrs,
-                      const OpContext &ctx,
-                      const std::vector<TBlob> &inputs,
-                      const std::vector<OpReqType> &req,
-                      const std::vector<TBlob> &outputs) {
+  template<typename OP>
+  static void Compute_(const nnvm::NodeAttrs &attrs,
+                       mshadow::Stream<cpu>* s,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
     DCHECK_EQ(inputs.size(), 1);
     DCHECK_EQ(outputs.size(), 1);
     using namespace mshadow;
     using namespace mshadow::expr;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
     const double alpha = nnvm::get<double>(attrs.parsed);
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+        mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
           s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), DType(alpha));
       });
     });
   }
 
+#if MXNET_USE_CUDA
+  template<typename OP>
+  static void Compute_(const nnvm::NodeAttrs &attrs,
+                       mshadow::Stream<gpu>* s,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs);
+#endif
+
+  template<typename xpu, typename OP>
+  static void Compute(const nnvm::NodeAttrs &attrs,
+                      const OpContext &ctx,
+                      const std::vector<TBlob> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<TBlob> &outputs) {
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    Compute_<OP>(attrs, s, inputs, req, outputs);
+  }
+
   template<typename xpu, typename OP>
   static void ComputeInt(const nnvm::NodeAttrs &attrs,
                          const OpContext &ctx,
@@ -335,26 +353,46 @@ class BinaryScalarOp : public UnaryOp {
     }
   }
 
-  template<typename xpu, typename OP>
-  static void Backward(const nnvm::NodeAttrs &attrs,
-                       const OpContext &ctx,
-                       const std::vector<TBlob> &inputs,
-                       const std::vector<OpReqType> &req,
-                       const std::vector<TBlob> &outputs) {
+  template<typename OP>
+  static void Backward_(const nnvm::NodeAttrs &attrs,
+                        mshadow::Stream<cpu>* s,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs) {
     using namespace mshadow;
     using namespace mshadow::expr;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
     const double alpha = nnvm::get<double>(attrs.parsed);
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         mxnet::op::mxnet_op::Kernel<mxnet::op::mxnet_op::op_with_req<
-          mxnet::op::mxnet_op::backward_grad_tuned<OP>, Req>, xpu>::
+          mxnet::op::mxnet_op::backward_grad_tuned<OP>, Req>, cpu>::
           Launch(s, inputs[0].Size(), outputs[0].dptr<DType>(),
                  inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
                  DType(alpha));
       });
     });
   }
+
+#if MXNET_USE_CUDA
+  template<typename OP>
+  static void Backward_(const nnvm::NodeAttrs &attrs,
+                        mshadow::Stream<gpu>* s,
+                        const std::vector<TBlob> &inputs,
+                        const std::vector<OpReqType> &req,
+                        const std::vector<TBlob> &outputs);
+#endif
+
+  template<typename xpu, typename OP>
+  static void Backward(const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const std::vector<TBlob> &inputs,
+                       const std::vector<OpReqType> &req,
+                       const std::vector<TBlob> &outputs) {
+    using namespace mshadow;
+    using namespace mshadow::expr;
+    Stream<xpu> *s = ctx.get_stream<xpu>();
+    Backward_<OP>(attrs, s, inputs, req, outputs);
+  }
 };
 
 #define MXNET_OPERATOR_REGISTER_BINARY_SCALAR(name)                 \
@@ -375,4 +413,9 @@ class BinaryScalarOp : public UnaryOp {
 
 }  // namespace op
 }  // namespace mxnet
+
+#ifdef __CUDACC__
+#include "elemwise_binary_scalar_op.cuh"
+#endif
+
 #endif  // MXNET_OPERATOR_TENSOR_ELEMWISE_BINARY_SCALAR_OP_H_
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu
index 3c83920..3fd017f 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_scalar_op_basic.cu
@@ -57,22 +57,19 @@ NNVM_REGISTER_OP(_rdiv_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rdiv>);
 
 NNVM_REGISTER_OP(_backward_rdiv_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu,
-  mshadow_op::rdiv_grad>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::rdiv_grad>);
 
 NNVM_REGISTER_OP(_mod_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::mod>);
 
 NNVM_REGISTER_OP(_backward_mod_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<
-  gpu, mshadow_op::mod_grad>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::mod_grad>);
 
 NNVM_REGISTER_OP(_rmod_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rmod>);
 
 NNVM_REGISTER_OP(_backward_rmod_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<
-  gpu, mshadow_op::rmod_grad>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::rmod_grad>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cu b/src/operator/tensor/elemwise_binary_scalar_op_extended.cu
index 2bd52d7..f09e40a 100644
--- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cu
+++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cu
@@ -44,30 +44,25 @@ NNVM_REGISTER_OP(_power_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::power>);
 
 NNVM_REGISTER_OP(_backward_power_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<
-  gpu, mshadow_op::power_grad>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::power_grad>);
 
 NNVM_REGISTER_OP(_rpower_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);
 
 NNVM_REGISTER_OP(_backward_rpower_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<
-  gpu, mshadow_op::rpower_grad>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::rpower_grad>);
 
 NNVM_REGISTER_OP(_hypot_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::hypot>);
 
 NNVM_REGISTER_OP(_backward_hypot_scalar)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<
-  gpu, mshadow_op::hypot_grad_left>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::hypot_grad_left>);
 
 NNVM_REGISTER_OP(smooth_l1)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<
-  gpu, mshadow_op::smooth_l1_loss>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::smooth_l1_loss>);
 
 NNVM_REGISTER_OP(_backward_smooth_l1)
-.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<
-  gpu, mshadow_op::smooth_l1_gradient>);
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::smooth_l1_gradient>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_sum.cu b/src/operator/tensor/elemwise_sum.cu
index f9a2482..352c74e 100644
--- a/src/operator/tensor/elemwise_sum.cu
+++ b/src/operator/tensor/elemwise_sum.cu
@@ -24,10 +24,118 @@
 */
 #include "./elemwise_sum.h"
 #include "../../ndarray/ndarray_function.h"
+#include "../../common/cuda_vectorization.cuh"
 
 namespace mxnet {
 namespace op {
 
+using common::cuda::VectorizedKernelLauncher;
+using common::cuda::VectorizedLoader;
+using common::cuda::VectorizedStorer;
+
+namespace {
+
+constexpr size_t num_inputs_per_kernel = 4;
+
+template <typename DType, int NumInputs>
+struct VectorizedElementwiseSumKernelParams {
+  int num_inputs;
+  const DType* inputs[NumInputs];
+  DType* outputs[1];
+};
+
+template <bool aligned, typename DType, typename LType, int req>
+__launch_bounds__(mxnet::common::cuda::vectorized_kernel_thread_num)
+__global__ void VectorizedElementwiseSumKernel(
+    const VectorizedElementwiseSumKernelParams<DType, num_inputs_per_kernel> params,
+    const index_t N) {
+  VectorizedStorer<DType, LType, aligned> storer(params.outputs[0], N);
+
+  const index_t M = storer.num_aligned_elements();
+
+  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+      tid < M;
+      tid += gridDim.x * blockDim.x) {
+    if (req == kAddTo) {
+      storer.load(tid, N);
+    } else {
+#pragma unroll
+      for (int i = 0; i < storer.nvec(); ++i) {
+        storer.separate()[i] = 0;
+      }
+    }
+#pragma unroll
+    for (int i = 0; i < num_inputs_per_kernel; ++i) {
+      if (i < params.num_inputs) {
+        VectorizedLoader<DType, LType, aligned> loader(params.inputs[i], N);
+        loader.load(tid, N);
+#pragma unroll
+        for (int i = 0; i < loader.nvec(); ++i) {
+          storer.separate()[i] += loader.separate()[i];
+        }
+      }
+    }
+
+    storer.store(tid, N);
+  }
+}
+
+
+template <typename DType, int req>
+class VectorizedElementwiseSumFwd {
+ public:
+  using ParamType = VectorizedElementwiseSumKernelParams<DType, num_inputs_per_kernel>;
+
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t lead_dim,
+                     const index_t /* other_dim */) {
+    VectorizedElementwiseSumKernel<aligned, DType, LType, req>
+      <<<blocks, threads, 0, stream>>>(params, lead_dim);
+  }
+};
+
+void VectorizedElementwiseSum(const nnvm::NodeAttrs &attrs,
+                              const OpContext &ctx,
+                              const std::vector<TBlob> &inputs,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<TBlob> &outputs) {
+  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
+  if (req[0] == kNullOp) return;
+  CHECK_EQ(outputs.size(), 1U);
+  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      using LType = uint2;
+      const index_t size = inputs[0].Size();
+      for (size_t i = 0; i < inputs.size(); i += num_inputs_per_kernel) {
+        if (i == 0) {
+          using Kernel = VectorizedElementwiseSumFwd<DType, Req>;
+          typename Kernel::ParamType params;
+          params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i);
+          for (int j = 0; j < params.num_inputs; ++j) {
+            params.inputs[j] = inputs[i + j].dptr<DType>();
+          }
+          params.outputs[0] = outputs[0].dptr<DType>();
+          VectorizedKernelLauncher<DType, LType, Kernel>(size, 1, s, params);
+        } else {
+          /* During subsequent launches we need to
+             accumulate into the previous outputs
+          */
+          using Kernel = VectorizedElementwiseSumFwd<DType, kAddTo>;
+          typename Kernel::ParamType params;
+          params.num_inputs = std::min(num_inputs_per_kernel, inputs.size() - i);
+          for (int j = 0; j < params.num_inputs; ++j) {
+            params.inputs[j] = inputs[i + j].dptr<DType>();
+          }
+          params.outputs[0] = outputs[0].dptr<DType>();
+          VectorizedKernelLauncher<DType, LType, Kernel>(size, 1, s, params);
+        }
+      }
+    });
+  });
+}
+
 void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<NDArray>& inputs,
@@ -51,8 +159,10 @@ void ElementWiseSumComputeExGPU(const nnvm::NodeAttrs& attrs,
   }
 }
 
+}  // namespace
+
 NNVM_REGISTER_OP(add_n)
-.set_attr<FCompute>("FCompute<gpu>", ElementWiseSumComputeWithHalf2<gpu>)
+.set_attr<FCompute>("FCompute<gpu>", VectorizedElementwiseSum)
 .set_attr<FComputeEx>("FComputeEx<gpu>", ElementWiseSumComputeExGPU);
 
 }  // namespace op
diff --git a/src/operator/tensor/elemwise_sum.h b/src/operator/tensor/elemwise_sum.h
index 259c80d..d40ab4d 100644
--- a/src/operator/tensor/elemwise_sum.h
+++ b/src/operator/tensor/elemwise_sum.h
@@ -113,18 +113,6 @@ void ElementWiseSumCompute(const nnvm::NodeAttrs& attrs,
   });
 }
 
-template<typename xpu>
-void ElementWiseSumComputeWithHalf2(const nnvm::NodeAttrs& attrs,
-                                    const OpContext& ctx,
-                                    const std::vector<TBlob>& inputs,
-                                    const std::vector<OpReqType>& req,
-                                    const std::vector<TBlob>& outputs) {
-  CHECK_EQ(outputs.size(), 1U);
-  MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
-      ElementWiseSumCompute_<xpu, DType>(attrs, ctx, inputs, req, outputs);
-  });
-}
-
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_OPERATOR_TENSOR_ELEMWISE_SUM_H_
diff --git a/src/operator/tensor/elemwise_unary_op.cuh b/src/operator/tensor/elemwise_unary_op.cuh
new file mode 100644
index 0000000..8688a8b
--- /dev/null
+++ b/src/operator/tensor/elemwise_unary_op.cuh
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2020 by Contributors
+ * \file elemwise_unary_op.cuh
+ * \brief GPU helpers for unary elementwise operators
+ */
+
+#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_CUH_
+#define MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_CUH_
+
+#include <cuda_runtime.h>
+#include "../operator_common.h"
+#include "../../common/cuda_vectorization.cuh"
+
+#include <vector>
+
+#if MXNET_USE_CUDA
+
+namespace mxnet {
+namespace op {
+
+namespace unary {
+
+using common::cuda::VectorizedKernelLauncher;
+using common::cuda::VectorizedLoader;
+using common::cuda::VectorizedStorer;
+
+template <typename DType, int NumInputs, int NumOutputs>
+struct VectorizedKernelParams {
+  const DType* inputs[NumInputs];
+  DType* outputs[NumOutputs];
+};
+
+template <bool aligned, typename DType, typename LType, typename OP, int req>
+__global__ void VectorizedUnaryScalarKernelFwd(const VectorizedKernelParams<DType, 1, 1> params,
+                                               const index_t N) {
+  VectorizedLoader<DType, LType, aligned> loader(params.inputs[0], N);
+  VectorizedStorer<DType, LType, aligned> storer(params.outputs[0], N);
+
+  const index_t M = loader.num_aligned_elements();
+
+  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+       tid < M;
+       tid += gridDim.x * blockDim.x) {
+    loader.load(tid, N);
+    if (req == kAddTo) {
+      storer.load(tid, N);
+    }
+#pragma unroll
+    for (int i = 0; i < loader.nvec(); ++i) {
+      DType temp = OP::Map(loader.separate()[i]);
+
+      if (req == kAddTo) {
+        storer.separate()[i] += temp;
+      } else {
+        storer.separate()[i] = temp;
+      }
+    }
+    storer.store(tid, N);
+  }
+}
+
+template <typename DType, typename OP, int req>
+class VectorizedUnaryScalarFwd {
+ public:
+  using ParamType = VectorizedKernelParams<DType, 1, 1>;
+
+  template <bool aligned, typename LType>
+  static void Launch(const index_t blocks, const index_t threads,
+                     cudaStream_t stream,
+                     const ParamType params, const index_t lead_dim,
+                     const index_t /* other_dim */) {
+    VectorizedUnaryScalarKernelFwd<aligned, DType, LType, OP, req>
+      <<<blocks, threads, 0, stream>>>(params, lead_dim);
+  }
+};
+
+}  // namespace unary
+
+template<typename OP>
+void UnaryOp::Compute_(const nnvm::NodeAttrs& attrs,
+                     mshadow::Stream<gpu>* s,
+                     const std::vector<TBlob>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<TBlob>& outputs) {
+  using namespace unary;
+  if (req[0] == kNullOp) return;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      using LType = uint4;
+      using Kernel = VectorizedUnaryScalarFwd<DType, OP, Req>;
+
+      const index_t size = outputs[0].Size();
+      typename Kernel::ParamType params;
+      params.inputs[0] = inputs[0].dptr<DType>();
+      params.outputs[0] = outputs[0].dptr<DType>();
+
+      VectorizedKernelLauncher<DType, LType, Kernel>(size, 1, s, params);
+    });
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_USE_CUDA
+#endif  // MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_CUH_
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index dcbd53a..86686c6 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -235,23 +235,42 @@ class UnaryOp : public OpBase {
     }
   }
 
-  template<typename xpu, typename OP>
-  static void Compute(const nnvm::NodeAttrs& attrs,
-                      const OpContext& ctx,
-                      const std::vector<TBlob>& inputs,
-                      const std::vector<OpReqType>& req,
-                      const std::vector<TBlob>& outputs) {
-    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  template<typename OP>
+  static void Compute_(const nnvm::NodeAttrs& attrs,
+                       mshadow::Stream<cpu>* s,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs) {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         if (inputs[0].Size() != 0) {
-          mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+          mxnet_op::Kernel<mxnet_op::op_with_req<OP, Req>, cpu>::Launch(
             s, inputs[0].Size(), outputs[0].dptr<DType>(), inputs[0].dptr<DType>());
         }
       });
     });
   }
 
+#if MXNET_USE_CUDA
+  template<typename OP>
+  static void Compute_(const nnvm::NodeAttrs& attrs,
+                       mshadow::Stream<gpu>* s,
+                       const std::vector<TBlob>& inputs,
+                       const std::vector<OpReqType>& req,
+                       const std::vector<TBlob>& outputs);
+
+#endif
+
+  template<typename xpu, typename OP>
+  static void Compute(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    Compute_<OP>(attrs, s, inputs, req, outputs);
+  }
+
   template<typename xpu, typename OP>
   static void ComputeInt(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
@@ -344,23 +363,6 @@ class UnaryOp : public OpBase {
   }
 #endif
 
-  template<typename xpu, typename op>
-  static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
-                               const OpContext &ctx,
-                               const std::vector<TBlob> &inputs,
-                               const std::vector<OpReqType> &req,
-                               const std::vector<TBlob> &outputs) {
-    using namespace mshadow;
-    using namespace mxnet_op;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    CHECK_EQ(inputs.size(), 1U);
-    CHECK_EQ(outputs.size(), 1U);
-    MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
-      Kernel<op, xpu>::Launch(s, outputs[0].Size(),
-                              outputs[0].dptr<DType>(), inputs[0].dptr<DType>());
-    });
-  }
-
   template<typename xpu>
   static void IdentityCompute(const nnvm::NodeAttrs& attrs,
                               const OpContext& ctx,
@@ -877,4 +879,8 @@ void NumpyNanToNumOpBackward(const nnvm::NodeAttrs& attrs,
 }  // namespace op
 }  // namespace mxnet
 
+#ifdef __CUDACC__
+#include "elemwise_unary_op.cuh"
+#endif
+
 #endif  // MXNET_OPERATOR_TENSOR_ELEMWISE_UNARY_OP_H_
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index e5b60b1..7c05507 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -22,6 +22,7 @@
  * \brief GPU Implementation of unary functions.
  */
 #include "./elemwise_binary_op.h"
+#include "./elemwise_unary_op.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/tensor/elemwise_unary_op_pow.cu b/src/operator/tensor/elemwise_unary_op_pow.cu
index 4dbdf34..287a2e8 100644
--- a/src/operator/tensor/elemwise_unary_op_pow.cu
+++ b/src/operator/tensor/elemwise_unary_op_pow.cu
@@ -22,6 +22,7 @@
  * \brief GPU Implementation of power (x^k for fixed k) functions.
  */
 #include "./elemwise_binary_op.h"
+#include "./elemwise_unary_op.h"
 
 namespace mxnet {
 namespace op {
diff --git a/src/operator/tensor/elemwise_unary_op_trig.cu b/src/operator/tensor/elemwise_unary_op_trig.cu
index 8e28b9c..f5e9d1c 100644
--- a/src/operator/tensor/elemwise_unary_op_trig.cu
+++ b/src/operator/tensor/elemwise_unary_op_trig.cu
@@ -22,6 +22,7 @@
  * \brief GPU Implementation of unary trigonometric function.
  */
 #include "./elemwise_binary_op.h"
+#include "./elemwise_unary_op.h"
 
 namespace mxnet {
 namespace op {
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index df4a77f..481cd00 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9894,6 +9894,84 @@ def test_elemwise_sum_for_gradient_accumulation():
         assert stored_grad['write'] == stored_grad['add']
         assert stored_grad['write'] == 2 * nrepeat
 
+@with_seed()
+def test_elementwise_ops_on_misaligned_input():
+    a = mx.nd.array([1,2,3,4], dtype='float16')
+    b = mx.nd.array([1,2,3,4], dtype='float16')
+
+    c = a[1:3]
+    d = b[1:3]
+    # Note: testing just elemwise_add since all elemwise_ops
+    #       share the implementation
+    mx.nd.elemwise_add(c, d, out=c)
+    mx.nd.waitall()
+
+    a = mx.nd.array([1,2,3,4], dtype='float16')
+    b = mx.nd.array([1,2,3,4], dtype='float16')
+
+    c = a[0:3]
+    d = b[0:3]
+    mx.nd.elemwise_add(c, d, out=c)
+    mx.nd.waitall()
+    assert a[3].asscalar() == 4.0
+
+@with_seed()
+def test_broadcast_ops_on_misaligned_input():
+    dtypes = ['float16', 'float32', 'float64']
+    lead_dims = [2,3,4,6,10]
+
+    for dtype in dtypes:
+        for lead_dim in lead_dims:
+            for both_ways in [False, True]:
+                shape = list(rand_shape_2d()) + [lead_dim]
+                small_shape = [shape[0], 1, lead_dim]
+                if both_ways:
+                    # Broadcast in both ways [1, K, L] x [M, 1, L]
+                    big_shape = [1, shape[1], lead_dim]
+                else:
+                    big_shape = shape
+                size = np.product(shape)
+                small_size = np.product(small_shape)
+                big_size = np.product(big_shape)
+                a = mx.nd.arange(5000)
+                b = mx.nd.arange(5000)
+                e = mx.nd.arange(5000)
+                c = a[1:big_size + 1].reshape(big_shape)
+                d = b[1:small_size + 1].reshape(small_shape)
+                f = e[1:size + 1].reshape(shape)
+                mx.nd.broadcast_add(c, d, out=f)
+                expected = c.asnumpy() + d.asnumpy()
+                mx.nd.waitall()
+                assert_almost_equal(f, expected)
+
+@with_seed()
+def test_broadcast_ops_on_misaligned_input_oneside():
+    dtypes = ['float16', 'float32', 'float64']
+    lead_dims = [2,3,4,6,10]
+
+    for dtype in dtypes:
+        for lead_dim in lead_dims:
+            for both_ways in [False, True]:
+                shape = list(rand_shape_2d()) + [lead_dim]
+                small_shape = [shape[0], shape[1], 1]
+                if both_ways:
+                    # Broadcast in both ways [1, K, L] x [M, 1, 1]
+                    big_shape = [1, shape[1], lead_dim]
+                else:
+                    big_shape = shape
+                size = np.product(shape)
+                small_size = np.product(small_shape)
+                big_size = np.product(big_shape)
+                a = mx.nd.arange(5000)
+                b = mx.nd.arange(5000)
+                e = mx.nd.arange(5000)
+                c = a[1:big_size + 1].reshape(big_shape)
+                d = b[1:small_size + 1].reshape(small_shape)
+                f = e[1:size + 1].reshape(shape)
+                mx.nd.broadcast_add(c, d, out=f)
+                expected = c.asnumpy() + d.asnumpy()
+                mx.nd.waitall()
+                assert_almost_equal(f, expected)
 
 if __name__ == '__main__':
     import nose

Reply via email to