This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new 9abb151  Backport to 1.6 (#16773, #16781, #16783, #16716, #16699, #16728, #16769, #16792) (#16832)
9abb151 is described below

commit 9abb1513982aa0e43cd4bda0568651fa28ec0938
Author: Przemyslaw Tredak <ptre...@nvidia.com>
AuthorDate: Fri Nov 15 17:50:38 2019 -0800

    Backport to 1.6 (#16773, #16781, #16783, #16716, #16699, #16728, #16769, #16792) (#16832)
    
    * Fix nightly build (#16773)
    
    * Remove dependency on tvmop.conf
    
    * Fix binary dependencies for nightly
    
    * Add comments
    
    * Update tvmop.py
    
    * Fix rebase
    
    * Fix (#16781)
    
    * Speed up fused_op compilation by caching ptx and jit-compiled functions (#16783)
    
    * [Numpy] Fix collect_params().zero_grad() in gluon numpy interface (#16716)
    
    * fix zero_grad
    
    * Update parameter.py
    
    * add test
    
    * fix
    
    * Mixed data type binary ops (#16699)
    
    * support mixed-precision binary operations
    
    * improvements to documentation and error messages
    
    * Support boolean elemwise/broadcast binary add, multiply and true_divide (#16728)
    
    * support pure boolean elemwise/broadcast binary op
    
    * switch to unique_ptr
    
    * fix the test error
    
    * Fix rtrue_divide grad (#16769)
    
    * Fix rtrue_divide_scalar
    
    * More tests
    
    * Fix numpy-compatible mean output type for integer inputs (#16792)
    
    * fix mean output type for integer inputs
    
    * enable for windows
---
 python/mxnet/gluon/parameter.py                    |  12 +-
 python/mxnet/ndarray/numpy/_op.py                  |  40 ++
 python/mxnet/numpy/multiarray.py                   |  40 ++
 python/mxnet/tvmop.py                              |  14 +-
 src/common/utils.h                                 |  30 +-
 src/executor/pointwise_fusion_pass.cc              |  27 +-
 src/ndarray/ndarray_function-inl.h                 |   2 +-
 src/operator/fusion/fused_op-inl.h                 |   4 +-
 src/operator/fusion/fused_op.cc                    |  50 ++-
 src/operator/fusion/fused_op.cu                    | 225 +++++++-----
 src/operator/fusion/fused_op.h                     |  43 ++-
 src/operator/mshadow_op.h                          | 114 +++++-
 src/operator/mxnet_op.h                            |  14 +-
 src/operator/numpy/np_broadcast_reduce_op.h        |  12 +-
 src/operator/numpy/np_broadcast_reduce_op_value.cc |  11 +-
 src/operator/numpy/np_elemwise_broadcast_op.cc     |  97 ++++-
 src/operator/numpy/np_elemwise_broadcast_op.cu     |  37 +-
 src/operator/numpy/np_elemwise_broadcast_op.h      | 404 +++++++++++++++++++++
 src/operator/numpy/np_true_divide.cc               |   4 +-
 src/operator/operator_tune-inl.h                   |   9 +-
 src/operator/operator_tune.cc                      |  10 +-
 src/operator/tensor/broadcast_reduce-inl.cuh       |  13 -
 src/operator/tensor/broadcast_reduce-inl.h         |  15 +-
 src/operator/tensor/broadcast_reduce_op.h          |   6 +-
 src/operator/tensor/elemwise_binary_broadcast_op.h |  32 ++
 src/operator/tensor/elemwise_binary_op.h           |  30 ++
 tests/nightly/JenkinsfileForBinaries               |   6 +-
 .../JenkinsfileForMBCC                             |   2 +-
 tests/python/gpu/test_fusion.py                    |  25 +-
 tests/python/unittest/test_numpy_gluon.py          |  17 +
 tests/python/unittest/test_numpy_ndarray.py        | 107 +++++-
 tests/python/unittest/test_numpy_op.py             | 229 ++++++++++--
 32 files changed, 1390 insertions(+), 291 deletions(-)

diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 957dc2c..067a357 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -27,7 +27,6 @@ __all__ = ['DeferredInitializationError', 'Parameter', 
'Constant',
 from collections import OrderedDict, defaultdict
 import warnings
 import numpy as np
-import mxnet as mx
 
 from ..base import mx_real_t, MXNetError
 from .. import symbol, ndarray, initializer, context
@@ -896,15 +895,20 @@ class ParameterDict(object):
                 continue
             for g in p.list_grad():
                 if g.stype == 'row_sparse':
-                    mx.ndarray.zeros_like(g, out=g)
+                    ndarray.zeros_like(g, out=g)
                 else:
                     arrays[g.context].append(g)
 
         if len(arrays) == 0:
             return
 
-        for arr in arrays.values():
-            mx.nd.reset_arrays(*arr, num_arrays=len(arr))
+        if is_np_array():
+            for arr in arrays.values():
+                for ele in arr:
+                    ele[()] = 0
+        else:
+            for arr in arrays.values():
+                ndarray.reset_arrays(*arr, num_arrays=len(arr))
 
     def reset_ctx(self, ctx):
         """Re-assign all Parameters to other contexts.
diff --git a/python/mxnet/ndarray/numpy/_op.py 
b/python/mxnet/ndarray/numpy/_op.py
index c215159..ff404a7 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -522,6 +522,14 @@ def add(x1, x2, out=None, **kwargs):
     -------
     add : ndarray or scalar
         The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 
are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), promotion is not supported yet.
     """
     return _ufunc_helper(x1, x2, _npi.add, _np.add, _npi.add_scalar, None, out)
 
@@ -548,6 +556,14 @@ def subtract(x1, x2, out=None, **kwargs):
     -------
     subtract : ndarray or scalar
         The difference of x1 and x2, element-wise. This is a scalar if both x1 
and x2 are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), promotion is not supported yet.
     """
     return _ufunc_helper(x1, x2, _npi.subtract, _np.subtract, 
_npi.subtract_scalar,
                          _npi.rsubtract_scalar, out)
@@ -575,6 +591,14 @@ def multiply(x1, x2, out=None, **kwargs):
     out : ndarray or scalar
         The multiplication of x1 and x2, element-wise. This is a scalar if 
both x1 and x2
         are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), promotion is not supported yet.
     """
     return _ufunc_helper(x1, x2, _npi.multiply, _np.multiply, 
_npi.multiply_scalar, None, out)
 
@@ -602,6 +626,14 @@ def divide(x1, x2, out=None, **kwargs):
     -------
     out : ndarray or scalar
         This is a scalar if both x1 and x2 are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), the output is of float32 type.
     """
     return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, 
_npi.true_divide_scalar,
                          _npi.rtrue_divide_scalar, out)
@@ -632,6 +664,14 @@ def true_divide(x1, x2, out=None):
     -------
     out : ndarray or scalar
         This is a scalar if both x1 and x2 are scalars.
+
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), the output is of float32 type.
     """
     return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, 
_npi.true_divide_scalar,
                          _npi.rtrue_divide_scalar, out)
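
The Notes added above document the mixed-precision behaviour introduced by #16699. A short sketch of the promotion rules as seen from the Python side, assuming an MXNet 1.6 build with the numpy interface (the mixed-precision kernels are guarded by #ifndef _WIN32 later in this patch); illustrative only, not part of the patch:

    from mxnet import np, npx
    npx.set_np()

    a16 = np.ones((2, 3), dtype='float16')
    b32 = np.ones((2, 3), dtype='float32')
    i32 = np.ones((2, 3), dtype='int32')

    print(np.add(a16, b32).dtype)           # float32: both floating-point, the more precise type wins
    print(np.multiply(i32, a16).dtype)      # float16: only one input is floating-point
    print(np.true_divide(i32, i32).dtype)   # float32: integer inputs promote to float32 for division
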
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 85bd2ac..c623f67 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -2399,6 +2399,14 @@ def add(x1, x2, out=None, **kwargs):
     add : ndarray or scalar
         The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 
are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), promotion is not supported yet.
+
     Examples
     --------
     >>> np.add(1.0, 4.0)
@@ -2437,6 +2445,14 @@ def subtract(x1, x2, out=None, **kwargs):
     subtract : ndarray or scalar
         The difference of x1 and x2, element-wise. This is a scalar if both x1 
and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), promotion is not supported yet.
+
     Examples
     --------
     >>> np.subtract(1.0, 4.0)
@@ -2473,6 +2489,14 @@ def multiply(x1, x2, out=None, **kwargs):
     out : ndarray or scalar
         The difference of x1 and x2, element-wise. This is a scalar if both x1 
and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), promotion is not supported yet.
+
     Examples
     --------
     >>> np.multiply(2.0, 4.0)
@@ -2511,6 +2535,14 @@ def divide(x1, x2, out=None, **kwargs):
     out : ndarray or scalar
         This is a scalar if both x1 and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), the output is of float32 type.
+
     Examples
     --------
     >>> np.true_divide(x, 4)
@@ -2545,6 +2577,14 @@ def true_divide(x1, x2, out=None):
     out : ndarray or scalar
         This is a scalar if both x1 and x2 are scalars.
 
+    Notes
+    -----
+    This operator now supports automatic type promotion. The resulting type will be
+    determined according to the following rules:
+        * If both inputs are floating-point types, the output is the more precise type.
+        * If only one of the inputs is a floating-point type, the result is that type.
+        * If both inputs are integer types (including boolean), the output is of float32 type.
+
     Examples
     --------
     >>> x = np.arange(5)
diff --git a/python/mxnet/tvmop.py b/python/mxnet/tvmop.py
index 9ec278a..29e6f0d 100644
--- a/python/mxnet/tvmop.py
+++ b/python/mxnet/tvmop.py
@@ -21,6 +21,7 @@ from .runtime import Features
 
 if Features().is_enabled("TVM_OP"):
     import json
+    import logging
 
     from ._ctypes.space import _set_tvm_op_config
     from .base import check_call, _LIB, c_str
@@ -31,7 +32,12 @@ if Features().is_enabled("TVM_OP"):
     check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0])))
 
     # op sch config
-    _CONF_TVM_OP = find_conf_path("tvmop")
-    with open(_CONF_TVM_OP[0], "r") as f:
-        ret = ConfigSpaces.from_json_dict(json.load(f))
-    _set_tvm_op_config(ret)
+    try:
+        _CONF_TVM_OP = find_conf_path("tvmop")
+    except RuntimeError as e:
+        logging.warning("TVM config file missing, falling back to default 
schedule", exc_info=True)
+    else:
+        logging.info("TVM op config has been loaded")
+        with open(_CONF_TVM_OP[0], "r") as f:
+            ret = ConfigSpaces.from_json_dict(json.load(f))
+        _set_tvm_op_config(ret)
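
The tvmop.py change removes the hard import-time dependency on the tvmop config file: a missing file now logs a warning and falls back to default schedules instead of raising. A generic sketch of the same try/except/else pattern (function and parameter names here are hypothetical, not MXNet APIs):

    import json
    import logging

    def load_optional_config(find_path, apply_config):
        # find_path may raise RuntimeError when the optional config file is absent
        try:
            path = find_path("tvmop")
        except RuntimeError:
            logging.warning("config file missing, falling back to defaults", exc_info=True)
        else:
            with open(path, "r") as f:
                apply_config(json.load(f))
            logging.info("config has been loaded")
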
diff --git a/src/common/utils.h b/src/common/utils.h
index b919cb3..0e3e354 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -365,6 +365,30 @@ inline bool ContainsStorageType(const std::vector<int>& 
ndstypes,
   return false;
 }
 
+inline std::string dtype_string(const int dtype) {
+  switch (dtype) {
+    case mshadow::kFloat32:
+      return "float";
+    case mshadow::kFloat64:
+      return "double";
+    case mshadow::kFloat16:
+      return "half";
+    case mshadow::kUint8:
+      return "unsigned char";
+    case mshadow::kInt8:
+      return "char";
+    case mshadow::kInt32:
+      return "int";
+    case mshadow::kInt64:
+      return "long long";
+    case mshadow::kBool:
+      return "bool";
+    default:
+      LOG(FATAL) << "Unknown type enum " << dtype;
+  }
+  return "unknown";
+}
+
 /*! \brief get string representation of dispatch_mode */
 inline std::string dispatch_mode_string(const DispatchMode x) {
   switch (x) {
@@ -842,7 +866,7 @@ inline bool is_float(const int dtype) {
   return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == 
mshadow::kFloat16;
 }
 
-inline int more_precise_type(const int type1, const int type2) {
+inline int get_more_precise_type(const int type1, const int type2) {
   if (type1 == type2) return type1;
   if (is_float(type1) && is_float(type2)) {
     if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) {
@@ -870,12 +894,12 @@ inline int more_precise_type(const int type1, const int 
type2) {
   return mshadow::kInt8;
 }
 
-inline int np_binary_out_type(const int type1, const int type2) {
+inline int np_binary_out_infer_type(const int type1, const int type2) {
   if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) ||
       (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) {
     return mshadow::kInt32;
   }
-  return more_precise_type(type1, type2);
+  return get_more_precise_type(type1, type2);
 }
 
 }  // namespace common
diff --git a/src/executor/pointwise_fusion_pass.cc 
b/src/executor/pointwise_fusion_pass.cc
index c6e2405..6fe2140 100644
--- a/src/executor/pointwise_fusion_pass.cc
+++ b/src/executor/pointwise_fusion_pass.cc
@@ -83,15 +83,7 @@ namespace {
     auto node = nnvm::Node::Create();
     subgraph_sym.outputs = subgraph.outputs;
     
node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node 
names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) {
-      if (n->op() != nullptr)
-        name_oss << n->op()->name << "_";
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    node->attrs.name = subgraph_name;
+    node->attrs.name = "FusedOp";
     node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
     node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
     node->attrs.op = Op::Get("_FusedOp");
@@ -152,7 +144,8 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const 
std::vector<NodeRawPtrSet>& sub
         auto it = node->control_deps.begin();
         static auto& is_fusion = 
Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
         std::vector<nnvm::NodePtr> new_control_deps;
-        while (it != node->control_deps.end()) {
+        // Use the first control dependency to get the inferattr helper
+        if (it != node->control_deps.end()) {
           if (subgraph_set.count(it->get())) {
             new_control_deps.push_back(*it);
           } else {
@@ -160,8 +153,7 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const 
std::vector<NodeRawPtrSet>& sub
               uint32_t node_id = subgraph_node->control_deps.size();
               subgraph_node->control_deps.push_back(*it);
               auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              subgraph_node->attrs.name + "_"
-                                              + node->attrs.name + 
"_outhelper",
+                                              "FusedOp_" + node->attrs.name + 
"_outhelper",
                                               nullptr,
                                               nullptr,
                                               nullptr);
@@ -180,6 +172,17 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const 
std::vector<NodeRawPtrSet>& sub
       }
     });
 
+    std::ostringstream name_oss;
+    // the name of the new node will be the concatenation of all the node 
names in the subgraph
+    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) {
+      if (n->op() != nullptr) {
+        name_oss << n->op()->name << "_";
+      }
+    });
+    auto subgraph_name = name_oss.str();
+    subgraph_name.pop_back();
+    subgraph_node->attrs.name = subgraph_name;
+
     const auto& index = subgraph.indexed_graph();
     DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const 
nnvm::NodePtr& node) {
       for (auto &e : node->control_deps) {
diff --git a/src/ndarray/ndarray_function-inl.h 
b/src/ndarray/ndarray_function-inl.h
index d494f08..6ac586a 100644
--- a/src/ndarray/ndarray_function-inl.h
+++ b/src/ndarray/ndarray_function-inl.h
@@ -379,7 +379,7 @@ void EvalRandom<DEVICE, GenNegBinomialDistribution>(
 template<>
 void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) {
   mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>();
-  MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, {
     ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs);
   });
 }
diff --git a/src/operator/fusion/fused_op-inl.h 
b/src/operator/fusion/fused_op-inl.h
index 3085bfd..2966fe2 100644
--- a/src/operator/fusion/fused_op-inl.h
+++ b/src/operator/fusion/fused_op-inl.h
@@ -982,11 +982,9 @@ const char kernel_begin[] = R"code(
 const int tid = threadIdx.x + blockIdx.x * blockDim.x;
 for (int i = tid; i < N; i+= gridDim.x * blockDim.x) {
     int offset = i*nvec;
-
 )code";
 
-const char kernel_end[] = R"code(
-}
+const char kernel_end[] = R"code(}
 }
 )code";
 
diff --git a/src/operator/fusion/fused_op.cc b/src/operator/fusion/fused_op.cc
index 071215b..5c83c30 100644
--- a/src/operator/fusion/fused_op.cc
+++ b/src/operator/fusion/fused_op.cc
@@ -49,31 +49,30 @@ void FusedOpParamParser(nnvm::NodeAttrs* attrs) {
   attrs->parsed = FusedOpPtr(new FusedOp(attrs, param));
 }
 
-FusedOp::FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config) {
-  this->inputs_ = std::vector<FusedOpEntry>(config.num_inputs);
-  this->outputs_ = std::vector<FusedOpEntry>(config.num_outputs);
-  this->subgraph_ = nnvm::Graph();
-  this->subgraph_.outputs = attrs->subgraphs[0]->outputs;
-  this->initialized_ = false;
-  this->cc_major_ = -1;
-  this->cc_minor_ = -1;
+FusedOp::FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config) :
+    initialized_(false),
+    kernel_function_dev_id_(-1) {
+  inputs_ = std::vector<FusedOpEntry>(config.num_inputs);
+  outputs_ = std::vector<FusedOpEntry>(config.num_outputs);
+  subgraph_ = nnvm::Graph();
+  subgraph_.outputs = attrs->subgraphs[0]->outputs;
 }
 
 bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs,
                          std::vector<mxnet::TShape> *in_attrs,
                          std::vector<mxnet::TShape> *out_attrs) {
-  this->subgraph_.attrs.erase("shape");
-  this->subgraph_.attrs.erase("shape_inputs");
+  subgraph_.attrs.erase("shape");
+  subgraph_.attrs.erase("shape_inputs");
   std::vector<mxnet::TShape> input_shapes(*in_attrs);
-  this->subgraph_ = mxnet::exec::InferShape(std::move(this->subgraph_),
-                                          std::move(input_shapes),
-                                          "__shape__");
+  subgraph_ = mxnet::exec::InferShape(std::move(subgraph_),
+                                      std::move(input_shapes),
+                                      "__shape__");
 
-  const auto& g = this->subgraph_.indexed_graph();
+  const auto& g = subgraph_.indexed_graph();
   const auto& input_nids = g.input_nodes();
 
   std::vector<mxnet::TShape> out_shapes;
-  const std::vector<mxnet::TShape> shapes = 
this->subgraph_.GetAttr<mxnet::ShapeVector>("shape");
+  const std::vector<mxnet::TShape> shapes = 
subgraph_.GetAttr<mxnet::ShapeVector>("shape");
   for (auto& e : g.outputs()) {
     out_shapes.push_back(shapes[g.entry_id(e)]);
   }
@@ -105,18 +104,18 @@ bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs,
 bool FusedOp::InferType(const nnvm::NodeAttrs &attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
-  this->subgraph_.attrs.erase("dtype");
-  this->subgraph_.attrs.erase("dtype_inputs");
+  subgraph_.attrs.erase("dtype");
+  subgraph_.attrs.erase("dtype_inputs");
   std::vector<int> input_types(*in_attrs);
-  this->subgraph_ = mxnet::exec::InferType(std::move(this->subgraph_),
-                                         std::move(input_types),
-                                         "__dtype__");
+  subgraph_ = mxnet::exec::InferType(std::move(subgraph_),
+                                     std::move(input_types),
+                                     "__dtype__");
 
-  const auto& g = this->subgraph_.indexed_graph();
+  const auto& g = subgraph_.indexed_graph();
   const auto& input_nids = g.input_nodes();
 
   std::vector<int> out_types;
-  const std::vector<int> types = 
this->subgraph_.GetAttr<nnvm::DTypeVector>("dtype");
+  const std::vector<int> types = subgraph_.GetAttr<nnvm::DTypeVector>("dtype");
   for (auto& e : g.outputs()) {
     out_types.push_back(types[g.entry_id(e)]);
   }
@@ -149,10 +148,9 @@ template <typename Attr>
 std::tuple<const nnvm::NodePtr,
            std::vector<Attr>,
            std::vector<Attr>>
-  FusedOp::GetAttrs(const std::string& attr_name,
-                                                                  const 
uint32_t node_id) {
-  const auto& g = this->subgraph_.indexed_graph();
-  const std::vector<Attr> attrs = 
this->subgraph_.GetAttr<std::vector<Attr>>(attr_name);
+FusedOp::GetAttrs(const std::string& attr_name, const uint32_t node_id) {
+  const auto& g = subgraph_.indexed_graph();
+  const std::vector<Attr> attrs = 
subgraph_.GetAttr<std::vector<Attr>>(attr_name);
   const auto& node = g[node_id];
   std::vector<Attr> inputs, outputs;
   for (const auto& e : node.inputs) {
diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index f6df38b..78988f1 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -163,9 +163,33 @@ void AddPointerAndShape(const TBlob& data,
   });
 }
 
+// Obtain compilation log from the program.
+std::string GetCompileLog(nvrtcProgram program) {
+  size_t log_size_including_null;
+  NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size_including_null));
+  // For most std::string implementations, this is probably 1 char bigger than 
needed.  OK though.
+  std::string log(log_size_including_null, '\0');
+  NVRTC_CALL(nvrtcGetProgramLog(program, &log[0]));
+  // Make sure the string reflects the true size (so minus the null 
terminator).
+  log.resize(log_size_including_null - 1);
+  return log;
+}
+
+// Obtain compilation result (ptx assembly) from the program.
+std::string GetPtx(nvrtcProgram program) {
+  size_t ptx_size_including_null;
+  NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null));
+  // For most std::string implementations, this is probably 1 char bigger than 
needed.  OK though.
+  std::string ptx(ptx_size_including_null, '\0');
+  NVRTC_CALL(nvrtcGetPTX(program, &ptx[0]));
+  // Make sure the string reflects the true size (so minus the null 
terminator).
+  ptx.resize(ptx_size_including_null - 1);
+  return ptx;
+}
+
 }  // namespace
 
-void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req,
+std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req,
                            const std::vector<int> &in_dtypes,
                            const std::vector<int> &out_dtypes,
                            const std::vector<int> &in_ndims,
@@ -175,7 +199,7 @@ void FusedOp::GenerateCode(int kernel_index, const 
std::vector<OpReqType> &req,
                            const int nvec,
                            const std::string &kernel_name,
                            std::vector<uint32_t>* check_shapes) {
-  const auto& g = this->subgraph_.indexed_graph();
+  const auto& g = subgraph_.indexed_graph();
   std::string code = "";
   int temp_name_counter = 0;
   using NodeEntry = nnvm::IndexedGraph::NodeEntry;
@@ -459,16 +483,11 @@ void FusedOp::GenerateCode(int kernel_index, const 
std::vector<OpReqType> &req,
     ++counter;
   }
 
-  this->code_[kernel_index] = code;
-
   // Add boilerplate and type information
-  if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) {
-    LOG(INFO) << code_[kernel_index];
-  }
   std::string kernel_params = "";
   std::string tensor_params = "";
   nnvm::Symbol sym;
-  sym.outputs = this->subgraph_.outputs;
+  sym.outputs = subgraph_.outputs;
   const std::vector<std::string> input_names = 
sym.ListInputNames(nnvm::Symbol::kAll);
   size_t num_params = in_dtypes.size() + out_dtypes.size();
   size_t i = 0;
@@ -513,85 +532,102 @@ void FusedOp::GenerateCode(int kernel_index, const 
std::vector<OpReqType> &req,
   }
   kernel_params += tensor_params;
 
-  code_[kernel_index] = std::string(fusion::fp16_support_string) + "\n" +
-          fusion::type_support_string + "\n" +
-          fusion::function_definitions + "\n" +
-          fusion::backward_function_definitions + "\n" +
-          aux_code + "\n" +
-          "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" +
-          "__global__ void FusedKernel_" + kernel_name +
-          "(size_t N, " + kernel_params + ") {\n" +
-          fusion::kernel_begin + "\n" +
-          code_[kernel_index] + "\n" +
-          fusion::kernel_end;
+  // Create kernel source (minus the common header)
+  return aux_code + "\n" +
+         "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" +
+         "__global__ void FusedKernel_" + kernel_name +
+         "(size_t N, " + kernel_params + ") {\n" +
+         fusion::kernel_begin + "\n" +
+         code + "\n" +
+         fusion::kernel_end;
 }
 
-void FusedOp::CompileCode(int kernel_index, const std::string &kernel_name) {
+CUfunction FusedOp::CompileCode(const std::string &code,
+                                const std::string &kernel_name,
+                                int dev_id) {
   // Guard NVRTC calls
   std::lock_guard<std::mutex> lock_nvrtc(mutex_);
-  nvrtcProgram program;
-  NVRTC_CALL(
-      nvrtcCreateProgram(&program,                                  // prog
-                         &code_[kernel_index][0],                              
   // buffer
-                         (kernel_name + "_kernel.cu").c_str(),      // name
-                         0,                                         // num 
headers
-                         NULL,                                      // headers
-                         NULL));                                    // include 
names
-  std::string gpu_arch = "--gpu-architecture=compute_" +
-                         std::to_string(this->cc_major_) +
-                         std::to_string(this->cc_minor_);
-
-  const char *opts[] = {gpu_arch.c_str(),
-                        "--std=c++11",
-                        "-default-device"};
-  const std::string kernel_name_demangled = "FusedKernel_" + kernel_name;
-  NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str()));
-
-  nvrtcResult compileResult = nvrtcCompileProgram(program,  // prog
-                                                  3,        // num options
-                                                  opts);    // options
-  // Obtain compilation log from the program.
-  size_t log_size;
-  NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size));
-  std::string log(log_size, '\0');
-  NVRTC_CALL(nvrtcGetProgramLog(program, &log[0]));
-  CHECK_EQ(compileResult, NVRTC_SUCCESS)
-    << "NVRTC Compilation failed. Please set environment variable 
MXNET_USE_FUSION to 0.\n" << log;
-  // Obtain PTX from the program.
-  size_t ptx_size;
-  NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size));
-  ptx_[kernel_index].reserve(ptx_size);
-  NVRTC_CALL(nvrtcGetPTX(program, &ptx_[kernel_index][0]));
-  const char *name;
-  NVRTC_CALL(nvrtcGetLoweredName(program,
-                                 kernel_name_demangled.c_str(),
-                                 &name));
-  kernel_name_[kernel_index] = name;
-  // Destroy the program.
-  NVRTC_CALL(nvrtcDestroyProgram(&program));
-  int device;
-  CUdevice cu_device;
-  CUcontext context;
-  CUmodule module;
-  CUDA_CALL(cudaGetDevice(&device));
-  CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, device));
-  CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
-  CUDA_DRIVER_CALL(cuModuleLoadData(&module, &ptx_[kernel_index][0]));
-  CUDA_DRIVER_CALL(cuModuleGetFunction(&kernel_[kernel_index],
-                                       module,
-                                       kernel_name_[kernel_index].c_str()));
+  // Local class for value type of compile cache
+  struct KernelInfo {
+    std::string mangled_name;
+    std::string ptx;
+    std::vector<CUfunction> functions;
+  };
+  // Maps from the cuda source code (minus header) to the ptx and jit-compiled 
CUfunctions.
+  using KernelCache = std::map<std::string, KernelInfo>;
+  // Per-gpu-architecture compiled kernel cache with jit-compiled function for 
each device context
+  static std::map<int32_t, KernelCache> compiled_kernels;
+  int sm_arch = SMArch(dev_id);
+  KernelCache& compiled_kernels_this_arch = compiled_kernels[sm_arch];  // 
make null map as needed
+  KernelInfo& kinfo = compiled_kernels_this_arch[code];                 // 
make KernelInfo as needed
+  if (kinfo.ptx.size() == 0) {
+    // It's the first time we've seen this kernel, so we need to generate the 
ptx and mangled_name.
+    static std::string common_header =
+        std::string(fusion::fp16_support_string) + "\n" +
+        fusion::type_support_string + "\n" +
+        fusion::function_definitions + "\n" +
+        fusion::backward_function_definitions + "\n";
+    std::string code_with_header = common_header + code;
+    // If verbose mode, output kernel source, though not including the common 
header
+    if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) {
+      LOG(INFO) << "\n" << std::string(80, '-') << "\n" << code;
+    }
+    if (compiled_kernels_this_arch.size() == CACHESIZE_WARN_THRESHOLD + 1 &&
+        dmlc::GetEnv("MXNET_FUSION_SIZE_WARNING", true)) {
+      LOG(WARNING) << "The number of different fused ops exceeds " << 
CACHESIZE_WARN_THRESHOLD
+                   << ".  Set MXNET_FUSION_SIZE_WARNING=0 to quiet this 
warning.";
+    }
+    nvrtcProgram program;
+    NVRTC_CALL(nvrtcCreateProgram(&program,                                  
// prog
+                                  &code_with_header[0],                      
// buffer
+                                  (kernel_name + "_kernel.cu").c_str(),      
// name
+                                  0,                                         
// num headers
+                                  NULL,                                      
// headers
+                                  NULL));                                    
// include names
+
+    std::string gpu_arch_arg = "--gpu-architecture=compute_" + 
std::to_string(sm_arch);
+    const char *opts[] = {gpu_arch_arg.c_str(),
+                          "--std=c++11",
+                          "-default-device"};
+    const std::string kernel_name_demangled = "FusedKernel_" + kernel_name;
+    NVRTC_CALL(nvrtcAddNameExpression(program, 
(kernel_name_demangled).c_str()));
+
+    nvrtcResult compileResult = nvrtcCompileProgram(program,  // prog
+                                                    3,        // num options
+                                                    opts);    // options
+    CHECK_EQ(compileResult, NVRTC_SUCCESS)
+        << "NVRTC Compilation failed. Please set environment variable 
MXNET_USE_FUSION to 0.\n"
+        << GetCompileLog(program);
+
+    kinfo.ptx = GetPtx(program);
+    const char *mangled_name;
+    NVRTC_CALL(nvrtcGetLoweredName(program,
+                                   kernel_name_demangled.c_str(),
+                                   &mangled_name));
+    kinfo.mangled_name = mangled_name;
+    // Destroy the program.
+    NVRTC_CALL(nvrtcDestroyProgram(&program));
+  }
+  // Ensure function array is deep enough to index by dev_id
+  while (kinfo.functions.size() <= static_cast<size_t>(dev_id))
+    kinfo.functions.push_back(static_cast<CUfunction>(nullptr));
+  // Jit-compile ptx for the device as needed
+  if (kinfo.functions[dev_id] == static_cast<CUfunction>(nullptr)) {
+    // Make sure driver context is set to the proper device
+    CUdevice cu_device;
+    CUcontext context;
+    CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id));
+    CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device));
+    // Jit-compile ptx for the driver's current context
+    CUmodule module;
+    CUDA_DRIVER_CALL(cuModuleLoadData(&module, kinfo.ptx.c_str()));
+    CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id],
+                                         module,
+                                         kinfo.mangled_name.c_str()));
+  }
+  return kinfo.functions[dev_id];
 }
 
-bool FusedOp::CheckComputeCapability(const OpContext &ctx) {
-  const int dev_id = ctx.run_ctx.ctx.dev_id;
-  const int cc_major = ComputeCapabilityMajor(dev_id);
-  const int cc_minor = ComputeCapabilityMinor(dev_id);
-
-  const bool ret = cc_major == this->cc_major_ && cc_minor == this->cc_minor_;
-  this->cc_major_ = cc_major;
-  this->cc_minor_ = cc_minor;
-  return ret;
-}
 
 void FusedOp::CheckShapesAndTypes(const std::vector<TBlob> &inputs,
                                   const std::vector<TBlob> &outputs,
@@ -665,23 +701,30 @@ void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs,
   const auto& node_shapes = intermediate_shapes_[0].internal_attr;
   const auto& node_dtypes = intermediate_dtypes_[0].internal_attr;
 
-  // Check and save compute capability of the current GPU
-  if (!CheckComputeCapability(ctx)) initialized_ = false;
+  int dev_id = ctx.run_ctx.ctx.dev_id;
 
+  // A change between training and inference modes may require different 
kernel functions
   initialized_ = initialized_ && (req == saved_reqs_);
   saved_reqs_ = req;
 
   if (!initialized_) {
-    this->GenerateCode(0, req, in_dtypes, out_dtypes, in_ndims, out_ndims,
+    const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, 
out_ndims,
                        node_shapes, node_dtypes, nvec, attrs.name, 
&check_shape_args_);
-    this->CompileCode(0, attrs.name);
+    kernel_functions_[fusion::kGeneral] = CompileCode(code, attrs.name, 
dev_id);
     if (check_shape_args_.size() > 0) {
-        this->GenerateCode(1, req, in_dtypes, out_dtypes, in_ndims, out_ndims,
+      const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, 
out_ndims,
                            node_shapes, node_dtypes, nvec, attrs.name, NULL);
-        this->CompileCode(1, attrs.name);
+      kernel_functions_[fusion::kShapeOptimized] = CompileCode(code, 
attrs.name, dev_id);
     }
     initialized_ = true;
+    kernel_function_dev_id_ = dev_id;
   }
+
+  // A change in device would force recompiling, but this is unexpected so 
signal as an error
+  if (dev_id != kernel_function_dev_id_)
+    LOG(FATAL) << "Fused op compiled for device " << kernel_function_dev_id_
+               <<  ", not expecting switch to device " << dev_id;
+
   Stream<gpu>* s = ctx.get_stream<gpu>();
   auto stream = Stream<gpu>::GetStream(s);
   std::vector<void*> args;
@@ -713,18 +756,18 @@ void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs,
   for (auto &ptr : ptrs) {
     args.push_back(reinterpret_cast<void *>(&ptr));
   }
-  int kernel_index = 0;
+  int kernel_variant = fusion::kGeneral;
   if (check_shape_args_.size() > 0) {
-      kernel_index = 1;
+    kernel_variant = fusion::kShapeOptimized;
       for (const auto &shape_id : check_shape_args_) {
           const auto& shape = node_shapes[shape_id];
           if (shape[shape.ndim()-1] % nvec != 0) {
-              kernel_index = 0;
+            kernel_variant = fusion::kGeneral;
           }
       }
   }
   CUDA_DRIVER_CALL(
-      cuLaunchKernel(kernel_[kernel_index],
+      cuLaunchKernel(kernel_functions_[kernel_variant],
         num_blocks, 1, 1,          // grid dim
         FusedOp::NTHREADS, 1, 1,   // block dim
         0, stream,                 // shared mem and stream
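
CompileCode above replaces the per-instance ptx_/kernel_ arrays with a process-wide cache: for each SM architecture, the generated source (without the common header) maps to its PTX and mangled name, plus one jit-compiled CUfunction per device id, so identical fused ops go through NVRTC only once. A rough Python sketch of that cache shape, for illustration only (nvrtc_compile and jit_load are stand-ins, not real APIs):

    from collections import defaultdict

    class KernelInfo:
        def __init__(self):
            self.mangled_name = None
            self.ptx = None
            self.functions = {}                 # dev_id -> jit-compiled function handle

    compiled_kernels = defaultdict(dict)        # sm_arch -> {kernel source -> KernelInfo}

    def compile_code(source, sm_arch, dev_id, nvrtc_compile, jit_load):
        kinfo = compiled_kernels[sm_arch].setdefault(source, KernelInfo())
        if kinfo.ptx is None:                   # first time this source is seen for this arch
            kinfo.ptx, kinfo.mangled_name = nvrtc_compile(source, sm_arch)
        if dev_id not in kinfo.functions:       # first time this device needs the kernel
            kinfo.functions[dev_id] = jit_load(kinfo.ptx, kinfo.mangled_name, dev_id)
        return kinfo.functions[dev_id]
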
diff --git a/src/operator/fusion/fused_op.h b/src/operator/fusion/fused_op.h
index 035e5432..24603ac 100644
--- a/src/operator/fusion/fused_op.h
+++ b/src/operator/fusion/fused_op.h
@@ -34,6 +34,12 @@
 
 namespace mxnet {
 
+namespace fusion {
+  enum KernelVariants {kGeneral, kShapeOptimized,
+    kNumKernelVariants  // Not a variant- leave this at the end
+  };
+}
+
 struct FusedOpConfig : public dmlc::Parameter<FusedOpConfig> {
   int num_inputs;
   int num_outputs;
@@ -53,6 +59,7 @@ struct FusedOpEntry {
 class FusedOp {
  public:
   static const int NTHREADS = 512;
+  static const int CACHESIZE_WARN_THRESHOLD = 10000;
 
   explicit FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config);
   ~FusedOp() {}
@@ -120,20 +127,20 @@ class FusedOp {
   }
 
  private:
-  void GenerateCode(int kernel_index,
-                    const std::vector<OpReqType> &req,
-                    const std::vector<int> &in_dtypes,
-                    const std::vector<int> &out_dtypes,
-                    const std::vector<int> &in_ndims,
-                    const std::vector<int> &out_ndims,
-                    const mxnet::ShapeVector &node_shapes,
-                    const std::vector<int> &node_dtypes,
-                    const int nvec,
-                    const std::string& kernel_name,
-                    std::vector<uint32_t> *check_shapes);
-  void CompileCode(int kernel_index,
-                   const std::string &kernel_name);
-  bool CheckComputeCapability(const OpContext &ctx);
+  std::string GenerateCode(const std::vector<OpReqType> &req,
+                           const std::vector<int> &in_dtypes,
+                           const std::vector<int> &out_dtypes,
+                           const std::vector<int> &in_ndims,
+                           const std::vector<int> &out_ndims,
+                           const mxnet::ShapeVector &node_shapes,
+                           const std::vector<int> &node_dtypes,
+                           const int nvec,
+                           const std::string& kernel_name,
+                           std::vector<uint32_t> *check_shapes);
+
+  CUfunction CompileCode(const std::string &code,
+                         const std::string &kernel_name, int dev_id);
+
   void CheckShapesAndTypes(const std::vector<TBlob> &inputs,
                            const std::vector<TBlob> &outputs,
                            std::vector<int> *in_dtypes,
@@ -145,7 +152,6 @@ class FusedOp {
   std::vector<FusedOpEntry> inputs_;
   std::vector<FusedOpEntry> outputs_;
 
-  std::string code_[2];
   nnvm::Graph subgraph_;
 
   template <typename T>
@@ -173,12 +179,9 @@ class FusedOp {
   std::vector<uint32_t> extra_shape_args_;
   std::vector<uint32_t> check_shape_args_;
 
-  std::string ptx_[2];
-  std::string kernel_name_[2];
-  CUfunction kernel_[2];
+  CUfunction kernel_functions_[fusion::kNumKernelVariants];
   bool initialized_;
-  int cc_major_;
-  int cc_minor_;
+  int kernel_function_dev_id_;
 
   static std::mutex mutex_;
   std::mutex my_mutex_;
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 1ece97b..d563f25 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -97,6 +97,18 @@ using std::is_integral;
     } \
   }
 
+#define MXNET_BINARY_MATH_OP_NC_WITH_BOOL(name, expr) \
+  struct name : public mxnet_op::tunable  { \
+    template<typename DType, \
+             typename std::enable_if<!std::is_same<DType, bool>::value, 
int>::type = 0> \
+    MSHADOW_XINLINE static DType Map(DType a, DType b) { \
+      return (expr); \
+    } \
+    MSHADOW_XINLINE static bool Map(bool a, bool b) { \
+      return (expr); \
+    } \
+  }
+
 #define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
   struct name : public mxnet_op::tunable  { \
     template<typename DType> \
@@ -192,13 +204,107 @@ MXNET_BINARY_MATH_OP_NC(left, a);
 
 MXNET_BINARY_MATH_OP_NC(right, b);
 
-MXNET_BINARY_MATH_OP_NC(mul, a * b);
+#ifndef _WIN32
+struct mixed_plus {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, 
mshadow::half::half_t b) {
+    return static_cast<mshadow::half::half_t>(a) + b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return static_cast<float>(a) + b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return static_cast<double>(a) + b;
+  }
+};
+
+struct mixed_minus {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, 
mshadow::half::half_t b) {
+    return static_cast<mshadow::half::half_t>(a) - b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return static_cast<float>(a) - b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return static_cast<double>(a) - b;
+  }
+};
+
+struct mixed_rminus {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, 
mshadow::half::half_t b) {
+    return b - static_cast<mshadow::half::half_t>(a);
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return b - static_cast<float>(a);
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return b - static_cast<double>(a);
+  }
+};
+
+struct mixed_mul {
+  template<typename DType,
+           typename std::enable_if<std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, 
mshadow::half::half_t b) {
+    return static_cast<mshadow::half::half_t>(a) * b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static float Map(DType a, float b) {
+    return static_cast<float>(a) * b;
+  }
+
+  template<typename DType,
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
+  MSHADOW_XINLINE static double Map(DType a, double b) {
+    return static_cast<double>(a) * b;
+  }
+};
+#endif
+
+MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b);
 
-MXNET_BINARY_MATH_OP_NC(div, a / b);
+MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b);
 
-MXNET_BINARY_MATH_OP_NC(plus, a + b);
+MXNET_BINARY_MATH_OP_NC_WITH_BOOL(plus, a + b);
 
-MXNET_BINARY_MATH_OP_NC(minus, a - b);
+MXNET_BINARY_MATH_OP_NC_WITH_BOOL(minus, a - b);
 
 MXNET_UNARY_MATH_OP(negation, -a);
 
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 5d297a5..b15117f 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -859,14 +859,17 @@ struct op_with_req {
 
   /*! \brief inputs are two tensors with a float output tensor */
   template<typename DType,
-           typename std::enable_if<std::is_integral<DType>::value, int>::type 
= 0>
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
   MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, 
const float *rhs) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
   }
 
   /*! \brief inputs are two tensors with a double output tensor */
   template<typename DType,
-           typename std::enable_if<std::is_integral<DType>::value, int>::type 
= 0>
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
   MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, 
const double *rhs) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
   }
@@ -883,14 +886,17 @@ struct op_with_req {
 
   /*! \brief inputs are two tensors with a float output tensor */
   template<typename DType,
-           typename std::enable_if<std::is_integral<DType>::value, int>::type 
= 0>
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
   MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, 
const float value) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
   }
 
   /*! \brief inputs are two tensors with a double output tensor */
   template<typename DType,
-           typename std::enable_if<std::is_integral<DType>::value, int>::type 
= 0>
+           typename std::enable_if<std::is_same<DType, 
mshadow::half::half_t>::value ||
+                                   std::is_same<DType, float>::value ||
+                                   std::is_integral<DType>::value, int>::type 
= 0>
   MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, 
const double value) {
     KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value));
   }
diff --git a/src/operator/numpy/np_broadcast_reduce_op.h 
b/src/operator/numpy/np_broadcast_reduce_op.h
index 4951f62..3566323 100644
--- a/src/operator/numpy/np_broadcast_reduce_op.h
+++ b/src/operator/numpy/np_broadcast_reduce_op.h
@@ -52,6 +52,7 @@ struct NumpyReduceAxesParam : public 
dmlc::Parameter<NumpyReduceAxesParam> {
       .add_enum("int8", mshadow::kInt8)
       .add_enum("int32", mshadow::kInt32)
       .add_enum("int64", mshadow::kInt64)
+      .add_enum("bool", mshadow::kBool)
       .set_default(dmlc::optional<int>())
       .describe("The type of the returned array and of the accumulator in 
which the elements are "
                 "summed. The dtype of a is used by default unless a has an 
integer dtype of less "
@@ -221,15 +222,15 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
                             const std::vector<TBlob>& inputs,
                             const std::vector<OpReqType>& req,
                             const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
   if (req[0] == kNullOp) return;
   const NumpyReduceAxesParam& param = 
nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
   if (param.initial.has_value()) {
     LOG(FATAL) << "initial is not supported yet";
   }
+  Stream<xpu>* s = ctx.get_stream<xpu>();
   if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) {
     using namespace mxnet_op;
-    using namespace mshadow;
-    Stream<xpu>* s = ctx.get_stream<xpu>();
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       Kernel<set_zero, xpu>::Launch(s, outputs[0].shape_.Size(), 
outputs[0].dptr<DType>());
     });
@@ -246,6 +247,13 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs,
       LOG(FATAL) << "Only reduce op: `sum` is supported for boolean ndarrays";
     }
     TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name);
+    if (normalize) {
+      using namespace mshadow::expr;
+      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+        auto out = outputs[0].FlatTo2D<xpu, OType>(s);
+        out /= scalar<OType>(inputs[0].Size()/outputs[0].Size());
+      });
+    }
     return;
   }
 #endif
diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc 
b/src/operator/numpy/np_broadcast_reduce_op_value.cc
index 435fe1d..fb13356 100644
--- a/src/operator/numpy/np_broadcast_reduce_op_value.cc
+++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc
@@ -257,13 +257,14 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs,
   const NumpyReduceAxesParam &param = 
nnvm::get<NumpyReduceAxesParam>(attrs.parsed);
 
   if (param.dtype.has_value()) {
-    if (IsIntType(in_attrs->at(0)) && !IsIntType(param.dtype.value())) {
-      LOG(FATAL) << "Output cannot be float type when input is integer type 
for now";
-    }
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
   } else {
-    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
-    TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    if (common::is_float(in_attrs->at(0))) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+      TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+    } else {
+      TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
+    }
   }
 
   return out_attrs->at(0) != -1 && in_attrs->at(0) != -1;
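
With the NumpyMeanType change above, np.mean on integer (or boolean) input defaults to a float32 output instead of keeping the integer dtype, while an explicit dtype argument is still honored. A small sketch, assuming an MXNet 1.6 build with the numpy interface; illustrative only:

    from mxnet import np, npx
    npx.set_np()

    x = np.array([1, 2, 3, 4], dtype='int32')
    print(np.mean(x), np.mean(x).dtype)          # 2.5 float32 (the output dtype used to follow the int input)
    print(np.mean(x, dtype='float64').dtype)     # float64: explicit dtype still wins
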
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc 
b/src/operator/numpy/np_elemwise_broadcast_op.cc
index c206ad4..a76e59d 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -23,8 +23,7 @@
  * \brief CPU Implementation of basic functions for elementwise numpy binary 
broadcast operator.
  */
 
-#include "../tensor/elemwise_binary_broadcast_op.h"
-#include "../tensor/elemwise_binary_scalar_op.h"
+#include "./np_elemwise_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
@@ -55,17 +54,99 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs,
   .add_argument("data", "NDArray-or-Symbol", "source input")        \
   .add_argument("scalar", "float", "scalar input")
 
+bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
+                                   std::vector<int>* in_attrs,
+                                   std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  const int ltype = in_attrs->at(0);
+  const int rtype = in_attrs->at(1);
+  if (ltype != -1 && rtype != -1 && (ltype != rtype)) {
+    // Only when both input types are known and not the same, we enter the 
mixed-precision mode
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, 
rtype));
+  } else {
+    return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs);
+  }
+  return true;
+}
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
op::mshadow_op::plus>)
+#ifndef _WIN32
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name)                
\
+  NNVM_REGISTER_OP(name)                                                       
\
+  .set_num_inputs(2)                                                           
\
+  .set_num_outputs(1)                                                          
\
+  .set_attr<nnvm::FListInputNames>("FListInputNames",                          
\
+    [](const NodeAttrs& attrs) {                                               
\
+      return std::vector<std::string>{"lhs", "rhs"};                           
\
+    })                                                                         
\
+  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)           
\
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType)     
\
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            
\
+    [](const NodeAttrs& attrs){                                                
\
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};                
\
+    })                                                                         
\
+  .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")     
\
+  .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
+#else
+#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name)                
\
+  NNVM_REGISTER_OP(name)                                                       
\
+  .set_num_inputs(2)                                                           
\
+  .set_num_outputs(1)                                                          
\
+  .set_attr<nnvm::FListInputNames>("FListInputNames",                          
\
+    [](const NodeAttrs& attrs) {                                               
\
+      return std::vector<std::string>{"lhs", "rhs"};                           
\
+    })                                                                         
\
+  .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)           
\
+  .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType)     
\
+  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                            
\
+    [](const NodeAttrs& attrs){                                                
\
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};                
\
+    })                                                                         
\
+  .set_attr<FResourceRequest>("FResourceRequest",                              
\
+  [](const NodeAttrs& attrs) {                                                 
\
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};          
\
+  })                                                                           
\
+  .add_argument("lhs", "NDArray-or-Symbol", "First input to the function")     
\
+  .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function")
+#endif
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add)
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus, 
op::mshadow_op::mixed_plus,
+                                      op::mshadow_op::mixed_plus>)
+#else
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_backward_broadcast_add"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
op::mshadow_op::minus>)
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract)
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus, 
op::mshadow_op::mixed_minus,
+                              op::mshadow_op::mixed_rminus>)
+#else
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_backward_broadcast_sub"});
 
-MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply)
-.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, 
op::mshadow_op::mul>)
+MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply)
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul, 
op::mshadow_op::mixed_mul,
+                                      op::mshadow_op::mixed_mul>)
+#else
+.set_attr<FCompute>(
+  "FCompute<cpu>",
+  NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_broadcast_mul"});
 
 MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod)
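
For illustration only (this snippet is not part of the committed diff): a minimal
Python sketch of the behaviour the mixed-precision registrations above enable,
assuming a build that includes this change and the promotion rules encoded in
common::np_binary_out_infer_type.

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()

    a = np.ones((2, 3), dtype='float32')
    b = np.ones((2, 3), dtype='float16')
    print((a + b).dtype)   # expected: float32 (the more precise float wins)

    i = np.ones((2, 3), dtype='int32')
    print((i * a).dtype)   # expected: float32 (integer promoted to the float type)
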
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu 
b/src/operator/numpy/np_elemwise_broadcast_op.cu
index a682ec9..a0a277d 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -22,20 +22,47 @@
  * \file np_elemwise_broadcast_op.cu
  * \brief GPU Implementation of basic functions for elementwise binary 
broadcast operator.
  */
-#include "../tensor/elemwise_binary_broadcast_op.h"
-#include "../tensor/elemwise_binary_scalar_op.h"
+
+#include "./np_elemwise_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_add)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
op::mshadow_op::plus>);
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus, 
op::mshadow_op::mixed_plus,
+                                      op::mshadow_op::mixed_plus>);
+#else
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus>);
+#endif
 
 NNVM_REGISTER_OP(_npi_subtract)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
op::mshadow_op::minus>);
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus, 
op::mshadow_op::mixed_minus,
+                              op::mshadow_op::mixed_rminus>);
+#else
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus>);
+#endif
 
 NNVM_REGISTER_OP(_npi_multiply)
-.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
op::mshadow_op::mul>);
+#ifndef _WIN32
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul, 
op::mshadow_op::mixed_mul,
+                                      op::mshadow_op::mixed_mul>);
+#else
+.set_attr<FCompute>(
+  "FCompute<gpu>",
+  NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>);
+#endif
 
 NNVM_REGISTER_OP(_npi_mod)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, 
mshadow_op::mod>);
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h 
b/src/operator/numpy/np_elemwise_broadcast_op.h
new file mode 100644
index 0000000..1a4596f
--- /dev/null
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2019 by Contributors
+ * \file np_elemwise_broadcast_op.h
+ * \brief Function definition of elemwise and broadcast operators
+ */
+#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
+#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
+
+#include <vector>
+#include <string>
+
+#include "../tensor/elemwise_binary_broadcast_op.h"
+#include "../tensor/elemwise_binary_scalar_op.h"
+
+namespace mxnet {
+namespace op {
+
+inline void PrintErrorMessage(const std::string& op_name, const int dtype1, 
const int dtype2) {
+  LOG(FATAL) << "Operator " << op_name << " does not support combination of "
+             << common::dtype_string(dtype1) << " with " << 
common::dtype_string(dtype2)
+             << " yet...";
+}
+
+#ifndef _WIN32
+template<typename xpu, typename OP>
+void MixedAllRealBinaryElemwiseCompute(const std::string& op_name,
+                                       const OpContext& ctx,
+                                       const TBlob& lhs,
+                                       const TBlob& rhs,
+                                       const TBlob& out,
+                                       const OpReqType req) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(lhs.type_flag_, out.type_flag_);
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
+    const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), 
rhs.Size())
+      + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+    if (size == 0) return;
+
+    switch (lhs.type_flag_) {
+      case mshadow::kFloat32:
+      {
+        if (rhs.type_flag_ == mshadow::kFloat16) {
+          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+              s, size, out.dptr<float>(), rhs.dptr<mshadow::half::half_t>(),
+              lhs.dptr<float>());
+          });
+        } else {
+          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
+        }
+        break;
+      }
+      case mshadow::kFloat64:
+      {
+        if (rhs.type_flag_ == mshadow::kFloat16) {
+          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+              s, size, out.dptr<double>(), rhs.dptr<mshadow::half::half_t>(),
+              lhs.dptr<double>());
+          });
+        } else if (rhs.type_flag_ == mshadow::kFloat32) {
+          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+              s, size, out.dptr<double>(), rhs.dptr<float>(),
+              lhs.dptr<double>());
+          });
+        } else {
+          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
+        }
+        break;
+      }
+      default:
+      {
+        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
+        break;
+      }
+    }
+  });
+}
+
+template<typename xpu, typename OP>
+void MixedIntRealBinaryElemwiseCompute(const OpContext& ctx,
+                                       const TBlob& lhs,
+                                       const TBlob& rhs,
+                                       const TBlob& out,
+                                       const OpReqType req) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(lhs.type_flag_, out.type_flag_);
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, FType, {
+    const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), 
rhs.Size())
+      + DataType<FType>::kLanes - 1) / DataType<FType>::kLanes;
+    if (size == 0) return;
+
+    MXNET_INT_TYPE_SWITCH(rhs.type_flag_, IType, {
+      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+        Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
+          s, size, out.dptr<FType>(), rhs.dptr<IType>(),
+          lhs.dptr<FType>());
+      });
+    });
+  });
+}
+
+template<typename xpu, typename LOP, typename ROP>
+void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<TBlob>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+
+  if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+    if (lhs.type_flag_ == out.type_flag_) {
+      MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, 
rhs, out, req[0]);
+    } else {
+      MixedAllRealBinaryElemwiseCompute<xpu, LOP>(attrs.op->name, ctx, rhs, 
lhs, out, req[0]);
+    }
+  } else if (common::is_float(lhs.type_flag_) || 
common::is_float(rhs.type_flag_)) {
+    if (lhs.type_flag_ == out.type_flag_) {
+      MixedIntRealBinaryElemwiseCompute<xpu, ROP>(ctx, lhs, rhs, out, req[0]);
+    } else {
+      MixedIntRealBinaryElemwiseCompute<xpu, LOP>(ctx, rhs, lhs, out, req[0]);
+    }
+  } else {
+    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+  }
+}
+
+template<typename xpu, typename OP>
+void MixedAllRealBinaryBroadcastCompute(const std::string& op_name,
+                                        const OpContext& ctx,
+                                        const TBlob& lhs,
+                                        const TBlob& rhs,
+                                        const TBlob& out,
+                                        const OpReqType req,
+                                        const int ndim,
+                                        const mxnet::TShape& new_oshape,
+                                        const mxnet::TShape& new_lshape,
+                                        const mxnet::TShape& new_rshape) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(lhs.type_flag_, out.type_flag_);
+
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  BROADCAST_NDIM_SWITCH(ndim, NDim, {
+    mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+    mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
+    mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
+    switch (lhs.type_flag_) {
+      case mshadow::kFloat32:
+      {
+        if (rhs.type_flag_ == mshadow::kFloat16) {
+          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, 
oshape,
+          rhs.dptr<mshadow::half::half_t>(), lhs.dptr<float>(), 
out.dptr<float>());
+        } else {
+          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
+        }
+        break;
+      }
+      case mshadow::kFloat64:
+      {
+        if (rhs.type_flag_ == mshadow::kFloat16) {
+          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, 
oshape,
+          rhs.dptr<mshadow::half::half_t>(), lhs.dptr<double>(), 
out.dptr<double>());
+        } else if (rhs.type_flag_ == mshadow::kFloat32) {
+          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, 
oshape,
+          rhs.dptr<float>(), lhs.dptr<double>(), out.dptr<double>());
+        } else {
+          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
+        }
+        break;
+      }
+      default:
+      {
+        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
+        break;
+      }
+    }
+  });
+}
+#endif
+
+#ifndef _WIN32
+template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
+void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<TBlob>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+
+#ifndef _WIN32
+  mxnet::TShape new_lshape, new_rshape, new_oshape;
+  int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
+                                         &new_lshape, &new_rshape, 
&new_oshape);
+  if (!ndim) {
+    MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, 
outputs);
+  } else {
+    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
+      if (lhs.type_flag_ == out.type_flag_) {
+        MixedAllRealBinaryBroadcastCompute<xpu, ROP>(
+          attrs.op->name, ctx, lhs, rhs, out, req[0], ndim, new_oshape, 
new_lshape, new_rshape);
+      } else {
+        MixedAllRealBinaryBroadcastCompute<xpu, LOP>(
+          attrs.op->name, ctx, rhs, lhs, out, req[0], ndim, new_oshape, 
new_rshape, new_lshape);
+      }
+    } else if (common::is_float(lhs.type_flag_) || 
common::is_float(rhs.type_flag_)) {
+      CHECK(lhs.type_flag_ == out.type_flag_ || rhs.type_flag_ == 
out.type_flag_)
+        << "One of the input type should be the same as the output";
+      BROADCAST_NDIM_SWITCH(ndim, NDim, {
+        mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+        mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
+        mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
+        if (lhs.type_flag_ == out.type_flag_) {
+          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, LType, {
+            MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
+              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, ROP>, 
xpu>::
+              template LaunchEx(s, new_oshape.Size(), req[0], rstride, 
lstride, oshape,
+              rhs.dptr<RType>(), lhs.dptr<LType>(), out.dptr<LType>());
+            });
+          });
+        } else {
+          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, RType, {
+            MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
+              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, LOP>, 
xpu>::
+              template LaunchEx(s, new_oshape.Size(), req[0], lstride, 
rstride, oshape,
+              lhs.dptr<LType>(), rhs.dptr<RType>(), out.dptr<RType>());
+            });
+          });
+        }
+      });
+    } else {
+      PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+    }
+  }
+#else
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
+    TBlob temp_tblob;
+    // one input is float, the other is a non-float (integer or boolean) type
+    CHECK((out.type_flag_ == lhs.type_flag_) || (out.type_flag_ == 
rhs.type_flag_))
+      << "This case out type should be same as the float type";
+    if (lhs.type_flag_ == out.type_flag_) {
+      MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, {
+        Tensor<xpu, 1, LType> temp_tensor =
+          ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), 
s);
+        temp_tblob = TBlob(temp_tensor);
+      });
+      CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
+      BinaryBroadcastCompute<xpu, OP>(
+        attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
+    } else {
+      MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, {
+        Tensor<xpu, 1, RType> temp_tensor =
+          ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), 
s);
+        temp_tblob = TBlob(temp_tensor);
+      });
+      CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
+      BinaryBroadcastCompute<xpu, OP>(
+        attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
+    }
+  } else {
+    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+  }
+#endif
+}
+
+#ifndef _WIN32
+template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
+void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
+                                 const OpContext& ctx,
+                                 const std::vector<TBlob>& inputs,
+                                 const std::vector<OpReqType>& req,
+                                 const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+
+  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;
+
+  if (lhs.type_flag_ == rhs.type_flag_) {
+    BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+
+#ifndef _WIN32
+  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, 
outputs);
+#else
+  MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+#endif
+}
+
+#ifndef _WIN32
+template<typename xpu, typename OP, typename LOP, typename ROP>
+#else
+template<typename xpu, typename OP>
+#endif
+void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
+                                         const OpContext& ctx,
+                                         const std::vector<TBlob>& inputs,
+                                         const std::vector<OpReqType>& req,
+                                         const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+
+  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;
+
+  if (lhs.type_flag_ == rhs.type_flag_) {
+    BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    return;
+  }
+
+#ifndef _WIN32
+  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, 
outputs);
+#else
+  MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
+#endif
+}
+
+template<typename xpu, typename LOP, typename ROP>
+void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<TBlob>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+
+  const TBlob& lhs = inputs[1];
+  const TBlob& rhs = inputs[2];
+  if (lhs.type_flag_ == rhs.type_flag_) {
+    BinaryBroadcastBackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, 
outputs);
+    return;
+  }
+
+  PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif  // MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
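
A minimal sketch, in plain NumPy, of the fallback strategy used by the _WIN32
branch of MixedBinaryBroadcastCompute above (cast the non-float operand to the
output type, then run the same-type kernel). The helper name mixed_binary is
illustrative and not an MXNet API.

    import numpy as np

    def mixed_binary(lhs, rhs, op):
        if lhs.dtype == rhs.dtype:
            return op(lhs, rhs)                          # same-type fast path
        # mixed path: compute in the dtype of the float operand
        out_dtype = lhs.dtype if np.issubdtype(lhs.dtype, np.floating) else rhs.dtype
        return op(lhs.astype(out_dtype), rhs.astype(out_dtype))

    out = mixed_binary(np.ones(3, dtype='float32'), np.ones(3, dtype='int32'), np.add)
    print(out.dtype)                                     # float32
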
diff --git a/src/operator/numpy/np_true_divide.cc 
b/src/operator/numpy/np_true_divide.cc
index d2135be..1e46cc9 100644
--- a/src/operator/numpy/np_true_divide.cc
+++ b/src/operator/numpy/np_true_divide.cc
@@ -31,7 +31,7 @@ namespace op {
 int TrueDivideOutType(int ltype, int rtype) {
   if (common::is_float(ltype) && common::is_float(rtype)) {
     // If both inputs are float, return the one with the higher precision
-    return common::more_precise_type(ltype, rtype);
+    return common::get_more_precise_type(ltype, rtype);
   } else if (common::is_float(ltype) || common::is_float(rtype)) {
     // If only one of the inputs is float, return that float type
     return (common::is_float(ltype)) ? ltype : rtype;
@@ -126,7 +126,7 @@ NNVM_REGISTER_OP(_npi_rtrue_divide_scalar)
   })
 #endif
 .set_attr<FCompute>("FCompute<cpu>", TrueDivideScalarCompute<cpu, 
mshadow_op::rtrue_divide>)
-.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseNone{"_backward_rdiv_scalar"})
+.set_attr<nnvm::FGradient>("FGradient", 
ElemwiseGradUseIn{"_backward_rdiv_scalar"})
 .add_argument("data", "NDArray-or-Symbol", "source input")
 .add_argument("scalar", "float", "scalar input");
 
diff --git a/src/operator/operator_tune-inl.h b/src/operator/operator_tune-inl.h
index 1dbcf42..122ec04 100644
--- a/src/operator/operator_tune-inl.h
+++ b/src/operator/operator_tune-inl.h
@@ -124,7 +124,8 @@ class OperatorTune : public OperatorTuneByType<DType> {
     if (!initialized_) {
       initialized_ = true;
       // Generate some random data for calling the operator kernels
-      data_set_.reserve(0x100);
+      data_set_ =
+        std::unique_ptr<DType[]>(reinterpret_cast<DType*>(new char[0x100 * 
sizeof(DType)]));
       std::random_device rd;
       std::mt19937 gen(rd());
       if (!std::is_integral<DType>::value) {
@@ -136,7 +137,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
             --n;
             continue;
           }
-          data_set_.emplace_back(val);
+          data_set_[n] = val;
         }
       } else {
         std::uniform_int_distribution<> dis(-128, 127);
@@ -147,7 +148,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
             --n;
             continue;
           }
-          data_set_.emplace_back(val);
+          data_set_[n] = val;
         }
       }
       // Use this environment variable to generate new tuning statistics
@@ -517,7 +518,7 @@ class OperatorTune : public OperatorTuneByType<DType> {
   /*! \brief Number of passes to obtain an average */
   static constexpr duration_t OUTSIDE_COUNT = (1 << OUTSIDE_COUNT_SHIFT);
   /*! \brief Random data for timing operator calls */
-  static std::vector<DType> data_set_;
+  static std::unique_ptr<DType[]> data_set_;
   /*! \brief Operators tuned */
   static std::unordered_set<std::string> operator_names_;
   /*! \brief Arbitary object to modify in OMP loop */
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index d0642ee..633f630 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -39,7 +39,7 @@ double OperatorTuneBase::tuning_weight_scale_ = 0.0;
  */
 #define IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(__typ$) \
   template<> bool OperatorTune<__typ$>::initialized_ = false; \
-  template<> std::vector<__typ$> OperatorTune<__typ$>::data_set_ = {}; \
+  template<> std::unique_ptr<__typ$[]> OperatorTune<__typ$>::data_set_ = 
nullptr; \
   template<> volatile tune::TuningMode 
OperatorTuneByType<__typ$>::tuning_mode_ = tune::kAuto; \
   template<> volatile int OperatorTune<__typ$>::volatile_int_ = 9;  /* 
arbitrary number */ \
   template<> std::unordered_set<std::string> 
OperatorTune<__typ$>::operator_names_({}); \
@@ -314,10 +314,10 @@ 
IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_logical_not);
 IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::nt);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::clip);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::clip);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul);  // NOLINT()
-IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::plus);  // 
NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::minus);  // 
NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::mul);  // 
NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::div);  // 
NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::true_divide);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus_sign);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus);  // NOLINT()
diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh 
b/src/operator/tensor/broadcast_reduce-inl.cuh
index fc863a9..360420a 100644
--- a/src/operator/tensor/broadcast_reduce-inl.cuh
+++ b/src/operator/tensor/broadcast_reduce-inl.cuh
@@ -619,8 +619,6 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const 
OpReqType req,
   ReduceImplConfig<ndim> config =
     ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL);
   if (safe_acc) {
-    // TODO(haojin2): Use real-only type swtich for windows temporarily due to 
CI issues.
-#ifndef _WIN32
     MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
       typedef typename std::conditional<safe_acc, AType, DataType>::type 
AccType;
       MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
@@ -630,17 +628,6 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const 
OpReqType req,
           stream, small, req, big, workspace, config);
       });
     });
-#else
-    MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, 
AType, {
-      typedef typename std::conditional<safe_acc, AType, DataType>::type 
AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
-        typedef typename std::conditional<safe_acc, OType, DataType>::type 
OutType;
-        config = ConfigureReduceImpl<ndim, AccType>(small.shape_, big.shape_, 
NULL, NULL);
-        ReduceImpl<Reducer, ndim, AccType, DataType, OutType, OP>(
-          stream, small, req, big, workspace, config);
-      });
-    });
-#endif
   } else {
     ReduceImpl<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, 
big, workspace, config);
   }
diff --git a/src/operator/tensor/broadcast_reduce-inl.h 
b/src/operator/tensor/broadcast_reduce-inl.h
index 415059a..e203da2 100644
--- a/src/operator/tensor/broadcast_reduce-inl.h
+++ b/src/operator/tensor/broadcast_reduce-inl.h
@@ -241,28 +241,15 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const 
OpReqType req,
       N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(),
       big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
   } else {
-    // TODO(haojin2): Use real-only type swtich for windows temporarily due to 
CI issues.
-#ifndef _WIN32
     MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, {
       typedef typename std::conditional<safe_acc, AType, DataType>::type 
AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, {
         typedef typename std::conditional<safe_acc, OType, DataType>::type 
OutType;
         seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
           N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
           big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
       });
     });
-#else
-    MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, 
AType, {
-      typedef typename std::conditional<safe_acc, AType, DataType>::type 
AccType;
-      MSHADOW_TYPE_SWITCH(small.type_flag_, OType, {
-        typedef typename std::conditional<safe_acc, OType, DataType>::type 
OutType;
-        seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>(
-          N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(),
-          big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride);
-      });
-    });
-#endif
   }
 }
 
diff --git a/src/operator/tensor/broadcast_reduce_op.h 
b/src/operator/tensor/broadcast_reduce_op.h
index 414b606..27e2249 100644
--- a/src/operator/tensor/broadcast_reduce_op.h
+++ b/src/operator/tensor/broadcast_reduce_op.h
@@ -617,7 +617,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx,
   BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
   MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       const TBlob in_data = inputs[0].reshape(src_shape);
       const TBlob out_data = outputs[0].reshape(dst_shape);
       BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, {
@@ -1045,8 +1045,8 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& 
attrs,
   mxnet::TShape src_shape, dst_shape;
   BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, 
&src_shape);
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, {
-    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, {
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape;
       mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape;
       for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) {
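
Usage sketch (not part of the patch) of what the widened *_WITH_BOOL type
switches above unlock, namely boolean inputs to reductions such as np.mean;
this assumes a build containing this change.

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()

    m = np.array([[True, False], [True, True]])
    print(np.mean(m))   # expected: 0.75 (the tests below use an int64 accumulator for bool)
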
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h 
b/src/operator/tensor/elemwise_binary_broadcast_op.h
index ad06df8..b48ed38 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -347,6 +347,9 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
   } else {
     if (req[0] != kNullOp) {
       mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+      if (outputs[0].type_flag_ == mshadow::kBool) {
+        LOG(FATAL) << "Operator " << attrs.op->name << " does not support 
boolean type";
+      }
       MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
         BROADCAST_NDIM_SWITCH(ndim, NDim, {
           mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
@@ -362,6 +365,35 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
 }
 
 template<typename xpu, typename OP>
+void BinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
+                                    const OpContext& ctx,
+                                    const std::vector<TBlob>& inputs,
+                                    const std::vector<OpReqType>& req,
+                                    const std::vector<TBlob>& outputs) {
+  if (outputs[0].shape_.Size() == 0U) return;
+  mxnet::TShape new_lshape, new_rshape, new_oshape;
+  int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, 
outputs[0].shape_,
+                                         &new_lshape, &new_rshape, 
&new_oshape);
+  if (!ndim) {
+    ElemwiseBinaryOp::ComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, 
outputs);
+  } else {
+    if (req[0] != kNullOp) {
+      mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
+        BROADCAST_NDIM_SWITCH(ndim, NDim, {
+          mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+          mshadow::Shape<NDim> lstride = 
mxnet_op::calc_stride(new_lshape.get<NDim>());
+          mshadow::Shape<NDim> rstride = 
mxnet_op::calc_stride(new_rshape.get<NDim>());
+          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, 
oshape,
+          inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), 
outputs[0].dptr<DType>());
+        });
+      });
+    }
+  }
+}
+
+template<typename xpu, typename OP>
 void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
                                  const std::vector<TBlob>& inputs,
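
Hedged usage sketch for the new BinaryBroadcastComputeWithBool path; the _npi_add
and _npi_multiply registrations earlier in this diff route pure-boolean operands
through it. The output dtypes noted here are expectations, not quoted from the patch.

    import mxnet as mx
    from mxnet import np, npx
    npx.set_np()

    x = np.array([True, False, True])
    y = np.array([[True], [False]])   # broadcasts against x
    print(x * y)                      # expected: boolean result of shape (2, 3)
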
diff --git a/src/operator/tensor/elemwise_binary_op.h 
b/src/operator/tensor/elemwise_binary_op.h
index 6f444ae..c046a28 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -99,11 +99,13 @@ class ElemwiseBinaryOp : public OpBase {
     return a1.var() == a2.var();
   }
 
+ public:
   /*! \brief Minimum of three */
   static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const 
size_t c) {
     return a < b ? (a < c ? a : c) : (b < c ? b : c);
   }
 
+ private:
   template<typename xpu, typename LOP, typename ROP, typename DType>
   static void BackwardUseNone_(const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
@@ -483,6 +485,9 @@ class ElemwiseBinaryOp : public OpBase {
       Stream<xpu> *s = ctx.get_stream<xpu>();
       CHECK_EQ(inputs.size(), 2U);
       CHECK_EQ(outputs.size(), 1U);
+      if (outputs[0].type_flag_ == mshadow::kBool) {
+        LOG(FATAL) << "Operator " << attrs.op->name << " does not support 
boolean type";
+      }
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
         MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
           const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
@@ -498,6 +503,31 @@ class ElemwiseBinaryOp : public OpBase {
   }
 
   template<typename xpu, typename OP>
+  static void ComputeWithBool(const nnvm::NodeAttrs &attrs,
+                              const OpContext &ctx,
+                              const std::vector<TBlob> &inputs,
+                              const std::vector<OpReqType> &req,
+                              const std::vector<TBlob> &outputs) {
+    using namespace mxnet_op;
+    if (req[0] != kNullOp) {
+      Stream<xpu> *s = ctx.get_stream<xpu>();
+      CHECK_EQ(inputs.size(), 2U);
+      CHECK_EQ(outputs.size(), 1U);
+      MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
+        MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, {
+          const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), 
inputs[1].Size())
+          + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
+          if (size != 0) {
+            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
+            outputs[0].dptr<DType>(),
+            inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
+          }
+        });
+      });
+    }
+  }
+
+  template<typename xpu, typename OP>
   static void ComputeLogic(const nnvm::NodeAttrs &attrs,
                            const OpContext &ctx,
                            const std::vector<TBlob> &inputs,
diff --git a/tests/nightly/JenkinsfileForBinaries 
b/tests/nightly/JenkinsfileForBinaries
index 48db445..2b55c05 100755
--- a/tests/nightly/JenkinsfileForBinaries
+++ b/tests/nightly/JenkinsfileForBinaries
@@ -18,9 +18,9 @@
 //
 //This is a Jenkinsfile for nightly tests. The format and some functions have 
been picked up from the top-level Jenkinsfile
 
-mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, 
lib/libtvmop.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
-mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, 
build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, 
build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, 
build/3rdparty/openmp/runtime/src/libomp.so'
-mx_lib_cpp_example_mkl = 'lib/libmxnet.so, lib/libmxnet.a, 
lib/libtvm_runtime.so, lib/libtvmop.so, 3rdparty/dmlc-core/libdmlc.a, 
3rdparty/tvm/nnvm/lib/libnnvm.a, build/cpp-package/example/imagenet_inference, 
lib/libmkldnn.so.1'
+mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, 
lib/libtvmop.so, lib/tvmop.conf, 3rdparty/dmlc-core/libdmlc.a, 
3rdparty/tvm/nnvm/lib/libnnvm.a'
+mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, 
build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, 
build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, 
build/3rdparty/openmp/runtime/src/libomp.so'
+mx_lib_cpp_example_mkl = 'lib/libmxnet.so, lib/libmxnet.a, 
lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, 
3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 
build/cpp-package/example/imagenet_inference, lib/libmkldnn.so.1'
 
 node('utility') {
   // Loading the utilities requires a node context unfortunately
diff --git 
a/tests/nightly/model_backwards_compatibility_check/JenkinsfileForMBCC 
b/tests/nightly/model_backwards_compatibility_check/JenkinsfileForMBCC
index 725261d..e419aa7 100644
--- a/tests/nightly/model_backwards_compatibility_check/JenkinsfileForMBCC
+++ b/tests/nightly/model_backwards_compatibility_check/JenkinsfileForMBCC
@@ -18,7 +18,7 @@
 //
 //This is a Jenkinsfile for the model backwards compatibility checker. The 
format and some functions have been picked up from the top-level Jenkinsfile.
 
-mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, 
lib/libtvm_runtime.so,lib/libtvmop.so, 3rdparty/dmlc-core/libdmlc.a, 
3rdparty/tvm/nnvm/lib/libnnvm.a'
+mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, 
lib/libtvmop.so, lib/tvmop.conf, 3rdparty/dmlc-core/libdmlc.a, 
3rdparty/tvm/nnvm/lib/libnnvm.a'
 
 node('restricted-utility') {
   // Loading the utilities requires a node context unfortunately
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 6adf935..5606eb1 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -28,6 +28,7 @@ from common import with_seed
 def check_fused_symbol(sym, **kwargs):
     inputs = sym.list_inputs()
     shapes = {inp : kwargs[inp].shape for inp in inputs}
+    ctx = kwargs.get('ctx', mx.gpu(0))
     # Double identity so that there is always something to fuse
     test_sym = mx.sym.Group([mx.sym.identity(mx.sym.identity(s)) for s in sym])
     rtol = {'float16' : 1e-2,
@@ -43,9 +44,9 @@ def check_fused_symbol(sym, **kwargs):
         for grad_req in ['write', 'add']:
             type_dict = {inp : dtype for inp in inputs}
             os.environ["MXNET_USE_FUSION"] = "0"
-            orig_exec = test_sym.simple_bind(ctx=mx.gpu(0), grad_req=grad_req, 
type_dict=type_dict, **shapes)
+            orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, 
type_dict=type_dict, **shapes)
             os.environ["MXNET_USE_FUSION"] = "1"
-            fused_exec = test_sym.simple_bind(ctx=mx.gpu(0), 
grad_req=grad_req, type_dict=type_dict, **shapes)
+            fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, 
type_dict=type_dict, **shapes)
             fwd_orig = orig_exec.forward(is_train=True, **data)
             out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig]
             orig_exec.backward(out_grads=out_grads)
@@ -218,6 +219,26 @@ def test_fusion():
     check_binary_ops()
     check_other_ops()
 
+@with_seed()
+def test_fusion_compiler_cache():
+    # Stresses the internal cache of CUfunctions by creating the same kernel 
multiple times and
+    # on multiple GPUs if available.
+    a = mx.sym.Variable('a')
+    b = mx.sym.Variable('b')
+    shape = rand_shape_2d()
+    arr1 = mx.random.uniform(shape=shape)
+    arr2 = mx.random.uniform(shape=shape)
+
+    # Invoke the same model twice; the second run exercises the compile cache
+    check_fused_symbol(a+b, ctx=mx.gpu(0), a=arr1, b=arr2)
+    check_fused_symbol(a+b, ctx=mx.gpu(0), a=arr1, b=arr2)
+
+    # On multi-GPU systems, invoke the same model on other GPUs
+    num_gpus = mx.context.num_gpus()
+    if num_gpus > 1:
+        check_fused_symbol(a+b, ctx=mx.gpu(1), a=arr1, b=arr2)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_numpy_gluon.py 
b/tests/python/unittest/test_numpy_gluon.py
index 12e89a2..15914db 100644
--- a/tests/python/unittest/test_numpy_gluon.py
+++ b/tests/python/unittest/test_numpy_gluon.py
@@ -179,6 +179,23 @@ def test_np_get_constant():
         assert_almost_equal(out.asnumpy(), (x.asnumpy() + const_arr), 
atol=1e-5, rtol=1e-4, use_broadcast=False)
 
 
+@use_np
+def test_parameters_zero_grad():
+    for hybridize in [False, True]:
+        net = gluon.nn.HybridSequential()
+        for _ in range(5):
+            net.add(gluon.nn.Dense(10))
+        if hybridize:
+            net.hybridize()
+        net.initialize()
+        out = net(mx.np.ones((32, 8)))
+        for v in net.collect_params().values():
+            v.grad()[()] = 1
+        net.collect_params().zero_grad()
+        for v in net.collect_params().values():
+            assert_almost_equal(v.grad().asnumpy(), 
mx.np.zeros_like(v.grad()).asnumpy())
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_numpy_ndarray.py 
b/tests/python/unittest/test_numpy_ndarray.py
index 239f300..8e46f03 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -27,7 +27,7 @@ from mxnet import np, npx, autograd
 from mxnet.gluon import HybridBlock
 from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, 
rand_ndarray, retry, use_np
 from common import with_seed, TemporaryDirectory
-from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, 
assert_exception, is_op_runnable
+from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, 
assert_exception, is_op_runnable, collapse_sum_like
 from mxnet.ndarray.ndarray import py_slice
 from mxnet.base import integer_types
 import scipy.stats as ss
@@ -281,6 +281,62 @@ def test_np_ndarray_binary_element_wise_ops():
             '<=': _np.less_equal
         })
 
+    def _get_grad_func(op, scalar=None, reverse=False):
+        if op == '+':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, 
x1.shape),
+                                                   collapse_sum_like(ograd, 
x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd
+            else:
+                return lambda ograd, x1, x2, out: ograd
+        elif op == '-':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, 
x1.shape),
+                                                   -collapse_sum_like(ograd, 
x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd
+            else:
+                return lambda ograd, x1, x2, out: -ograd
+        elif op == '*':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd * 
x2, x1.shape),
+                                                   collapse_sum_like(ograd * 
x1, x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd * x2
+            else:
+                return lambda ograd, x1, x2, out: ograd * x1
+        elif op == '/':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd / 
x2, x1.shape),
+                                                   collapse_sum_like(-x1 * 
ograd / (x2 * x2), x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd / x2
+            else:
+                return lambda ograd, x1, x2, out: -x1 * ograd / (x2 * x2)
+        elif op == 'mod':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, 
x1.shape),
+                                                   collapse_sum_like(-ograd * 
_np.floor(x1 / x2), x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd
+            else:
+                return lambda ograd, x1, x2, out: -ograd * _np.floor(x1 / x2)
+        elif op == 'pow':
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (collapse_sum_like(ograd * 
x2 * _np.power(x1, x2 - 1), x1.shape),
+                                                   collapse_sum_like(ograd * 
out * _np.log(x1), x2.shape))
+            elif not reverse:
+                return lambda ograd, x1, x2, out: ograd * x2 * _np.power(x1, 
x2 - 1)
+            else:
+                return lambda ograd, x1, x2, out: ograd * out * _np.log(x1)
+        elif op in ('==', '!=', '<', '<=', '>', '>='):
+            if scalar is None:
+                return lambda ograd, x1, x2, out: (_np.zeros_like(x1), 
_np.zeros_like(x2))
+            else:
+                return lambda ograd, x1, x2, out: _np.zeros_like(ograd)
+        return None
+
     def get_np_ret(x1, x2, op):
         return np_op_map[op](x1, x2)
 
@@ -364,13 +420,15 @@ def test_np_ndarray_binary_element_wise_ops():
             mx_input1 = abs(_np.random.uniform()) + 1
             np_input1 = mx_input1
         else:
-            mx_input1 = rand_ndarray(shape1, dtype=dtype).abs() + 1
+            mx_input1 = (rand_ndarray(shape1, dtype=dtype).abs() + 
1).as_np_ndarray()
+            mx_input1.attach_grad()
             np_input1 = mx_input1.asnumpy()
         if shape2 is None:
             mx_input2 = abs(_np.random.uniform()) + 1
             np_input2 = mx_input2
         else:
-            mx_input2 = rand_ndarray(shape2, dtype=dtype).abs() + 1
+            mx_input2 = (rand_ndarray(shape2, dtype=dtype).abs() + 
1).as_np_ndarray()
+            mx_input2.attach_grad()
             np_input2 = mx_input2.asnumpy()
 
         scalar = None
@@ -382,7 +440,9 @@ def test_np_ndarray_binary_element_wise_ops():
             scalar = mx_input1
             reverse = True
 
+        grad_func = _get_grad_func(op, scalar, reverse)
         np_out = get_np_ret(np_input1, np_input2, op)
+        ograd = _np.ones_like(np_out)
         for hybridize in [True, False]:
             if scalar is None:
                 get_mx_ret_np = TestBinaryElementWiseOp(op)
@@ -390,26 +450,49 @@ def test_np_ndarray_binary_element_wise_ops():
                 if hybridize:
                     get_mx_ret_np.hybridize()
                     get_mx_ret_classic.hybridize()
-                mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), 
mx_input2.as_np_ndarray())
+                if grad_func is None:
+                    mx_out = get_mx_ret_np(mx_input1, mx_input2)
+                else:
+                    with mx.autograd.record():
+                        mx_out = get_mx_ret_np(mx_input1, mx_input2)
+                    mx_out.backward()
                 assert type(mx_out) == np.ndarray
-                assert np_out.shape == mx_out.shape
                 if op in logic_ops:
                     assert np_out.dtype == mx_out.dtype
-                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, 
rtol=1e-5)
+                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, 
rtol=1e-5, use_broadcast=False)
+
+                if grad_func is not None:
+                    x1_grad_expected, x2_grad_expected = grad_func(ograd, 
np_input1, np_input2, np_out)
+                    assert_almost_equal(mx_input1.grad.asnumpy(), 
x1_grad_expected, atol=1e-5, rtol=1e-3,
+                                        use_broadcast=False)
+                    assert_almost_equal(mx_input2.grad.asnumpy(), 
x2_grad_expected, atol=1e-5, rtol=1e-3,
+                                        use_broadcast=False)
             else:
                 get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, 
reverse=reverse)
                 if hybridize:
                     get_mx_ret.hybridize()
                 if reverse:
-                    mx_out = get_mx_ret(mx_input2.as_np_ndarray())
-                    assert type(mx_out) == np.ndarray
+                    mx_input = mx_input2
                 else:
-                    mx_out = get_mx_ret(mx_input1.as_np_ndarray())
-                    assert type(mx_out) == np.ndarray
-                assert np_out.shape == mx_out.shape
+                    mx_input = mx_input1
+
+                if grad_func is None:
+                    mx_out = get_mx_ret(mx_input)
+                else:
+                    with mx.autograd.record():
+                        mx_out = get_mx_ret(mx_input)
+                    mx_out.backward()
+                assert type(mx_out) == np.ndarray
+
                 if op in logic_ops:
                     assert np_out.dtype == mx_out.dtype
-                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, 
rtol=1e-5)
+                assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, 
rtol=1e-5, use_broadcast=False)
+
+                # check grad
+                if grad_func is not None:
+                    x_grad_expected = grad_func(ograd, np_input1, np_input2, 
np_out)
+                    assert_almost_equal(mx_input.grad.asnumpy(), 
x_grad_expected, atol=1e-5, rtol=1e-3,
+                                        use_broadcast=False)
 
     dtypes = [_np.float32, _np.float64, None]
     ops = np_op_map.keys()
diff --git a/tests/python/unittest/test_numpy_op.py 
b/tests/python/unittest/test_numpy_op.py
index c1a6ed5..9a36e06 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -614,52 +614,83 @@ def test_np_mean():
     def is_int(dtype):
         return 'int' in dtype
 
+    is_windows = sys.platform.startswith('win')
     in_data_dim = random.choice([2, 3, 4])
     shape = rand_shape_nd(in_data_dim, dim=3)
     acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 
'float64',
-                'int8': 'int32', 'int32': 'int64', 'int64': 'int64'}
+                'bool': 'int64', 'int8': 'int32', 'int32': 'int64', 'int64': 
'int64'}
+    ft_types = ['float16', 'float32', 'float64']
+    it_types = ['bool', 'int8', 'int32', 'int64']
     for hybridize in [False, True]:
         for keepdims in [True, False]:
             for axis in ([i for i in range(in_data_dim)] + [(), None]):
-                for itype in ['float16', 'float32', 'float64']:
-                    for dtype in ['float16', 'float32', 'float64']:
-                        if is_int(dtype) and not is_int(itype):
-                            continue
-                        # test gluon
-                        test_mean = TestMean(axis=axis, dtype=dtype, 
keepdims=keepdims)
-                        if hybridize:
-                            test_mean.hybridize()
-                        if is_int(itype):
-                            x = _np.random.randint(-128, 128, shape, 
dtype=itype)
-                            x = mx.nd.array(x, dtype=itype)
-                        else:
-                            x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, 
dtype=itype)
-                        x = x.as_np_ndarray()
-                        x.attach_grad()
+                for itype, dtype in itertools.product(ft_types, [None] + 
ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, 
keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+                    x = np.random.uniform(-1.0, 1.0, size=shape).astype(itype)
+                    x = x.as_np_ndarray()
+                    x.attach_grad()
 
-                        expected_ret = _np.mean(x.asnumpy(), axis=axis, 
dtype=acc_type[itype], keepdims=keepdims)
-                        expected_ret = expected_ret.astype(dtype)
-                        with mx.autograd.record():
-                            y = test_mean(x)
-                        assert y.shape == expected_ret.shape
-                        assert_almost_equal(y.asnumpy(), expected_ret, 
rtol=1e-3 if dtype == 'float16' else 1e-3,
-                                            atol=1e-5 if dtype == 'float16' 
else 1e-5)
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, 
dtype=acc_type[itype], keepdims=keepdims)
+                    expected_ret = expected_ret.astype(dtype)
+                    with mx.autograd.record():
+                        y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 
if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 
1e-5)
 
-                        y.backward()
-                        N = x.size / y.size
-                        assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, 
dtype=x.dtype) / N)
+                    y.backward()
+                    N = x.size / y.size
+                    assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, 
dtype=x.dtype) / N)
 
-                        # test numeric
-                        if itype == 'float32' and dtype == 'float32':
-                            x_sym = mx.sym.Variable("x").as_np_ndarray()
-                            mx_sym = mx.sym.np.mean(x_sym, axis=axis, 
dtype=dtype, keepdims=keepdims).as_nd_ndarray()
-                            check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
-                                                   numeric_eps=1e-3, 
rtol=1e-3, atol=1e-4, dtype=_np.float32)
+                    # test numeric
+                    if itype == 'float32' and dtype == 'float32':
+                        x_sym = mx.sym.Variable("x").as_np_ndarray()
+                        mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, 
keepdims=keepdims).as_nd_ndarray()
+                        check_numeric_gradient(mx_sym, [x.as_nd_ndarray()],
+                                               numeric_eps=1e-3, rtol=1e-3, 
atol=1e-4, dtype=_np.float32)
 
-                        # test imperative
-                        mx_out = np.mean(x, axis=axis, dtype=dtype, 
keepdims=keepdims)
-                        np_out = _np.mean(x.asnumpy(), axis=axis, 
dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
-                        assert_almost_equal(mx_out.asnumpy(), np_out, 
rtol=1e-3, atol=1e-5)
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, 
keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, 
dtype=acc_type[itype], keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, 
atol=1e-5)
+
+                for itype, dtype in itertools.product(it_types, [None] + 
ft_types + it_types):
+                    if dtype == 'bool':
+                        continue
+                    # test gluon
+                    test_mean = TestMean(axis=axis, dtype=dtype, 
keepdims=keepdims)
+                    if hybridize:
+                        test_mean.hybridize()
+
+                    if itype == 'bool':
+                        x = np.array(_np.random.uniform(size=shape) > 0.5)
+                    else:
+                        x = np.random.uniform(-128, 127, 
size=shape).astype(itype)
+
+                    expected_ret = _np.mean(x.asnumpy(), axis=axis, 
dtype=dtype, keepdims=keepdims)
+
+                    if itype == 'bool':
+                        if is_op_runnable() and (not is_windows) and dtype not 
in ['float16', 'int8']:  # special handling of boolean ndarray
+                            y = test_mean(x)
+                            assert y.shape == expected_ret.shape
+                            assert_almost_equal(y.asnumpy(), expected_ret, 
rtol=1e-3 if dtype == 'float16' else 1e-3,
+                                                atol=1e-5 if dtype == 
'float16' else 1e-5)
+                        continue
+
+                    y = test_mean(x)
+                    assert y.shape == expected_ret.shape
+                    assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 
if dtype == 'float16' else 1e-3,
+                                        atol=1e-5 if dtype == 'float16' else 
1e-5)
+
+                    # test imperative
+                    mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims)
+                    np_out = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims).astype(dtype)
+                    assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
 
 
 @with_seed()
@@ -1572,8 +1603,8 @@ def test_np_binary_funcs():
                                             rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
                         if rgrads is None:
                             assert_almost_equal(mx_test_x2.grad.asnumpy(),
-                                               collapse_sum_like(rgrad(y.asnumpy(), np_test_x2, np_test_x1), mx_test_x2.shape),
-                                               rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
+                                                collapse_sum_like(rgrad(y.asnumpy(), np_test_x2, np_test_x1), mx_test_x2.shape),
+                                                rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False)
                         else:
                             assert_almost_equal(mx_test_x2.grad.asnumpy(),
                                                 collapse_sum_like(rgrad(y.asnumpy(), np_test_x1, np_test_x2), mx_test_x2.shape),
@@ -1594,7 +1625,6 @@ def test_np_binary_funcs():
                 assertRaises(NotImplementedError, getattr(np, func), mx_test_x1, mx_test_x2,  order='C')
                 assertRaises(NotImplementedError, getattr(np, func), mx_test_x1, mx_test_x2,  order='mxnet')
 
-
     funcs = {
         'add': (-1.0, 1.0, [lambda y, x1, x2: _np.ones(y.shape)], None),
         'subtract':
@@ -1603,7 +1633,7 @@ def test_np_binary_funcs():
         'multiply': (-1.0, 1.0, [lambda y, x1, x2: _np.broadcast_to(x2, y.shape)],
                                 [lambda y, x1, x2: _np.broadcast_to(x1, y.shape)]),
         'divide': (0.1, 1.0, [lambda y, x1, x2: _np.ones(y.shape) / x2],
-                               [lambda y, x1, x2: -x1 / (x2 * x2)]),
+                   [lambda y, x1, x2: -x1 / (x2 * x2)]),
         'mod': (1.0, 10.0,
                 [lambda y, x1, x2: _np.ones(y.shape),
                  lambda y, x1, x2: _np.zeros(y.shape)],
@@ -1652,6 +1682,125 @@ def test_np_binary_funcs():
 
 @with_seed()
 @use_np
+def test_np_mixed_precision_binary_funcs():
+    def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype):
+        class TestMixedBinary(HybridBlock):
+            def __init__(self, func):
+                super(TestMixedBinary, self).__init__()
+                self._func = func
+
+            def hybrid_forward(self, F, a, b, *args, **kwargs):
+                return getattr(F.np, self._func)(a, b)
+
+        np_func = getattr(_np, func)
+        mx_func = TestMixedBinary(func)
+        np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype)
+        np_test_x2 = _np.random.uniform(low, high, rshape).astype(rtype)
+        mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ltype)
+        mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rtype)
+        rtol = 1e-2 if ltype is np.float16 or rtype is np.float16 else 1e-3
+        atol = 1e-4 if ltype is np.float16 or rtype is np.float16 else 1e-5
+        for hybridize in [True, False]:
+            if hybridize:
+                mx_func.hybridize()
+            np_out = np_func(np_test_x1, np_test_x2)
+            with mx.autograd.record():
+                y = mx_func(mx_test_x1, mx_test_x2)
+            assert y.shape == np_out.shape
+            assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol,
+                                use_broadcast=False, equal_nan=True)
+
+        np_out = getattr(_np, func)(np_test_x1, np_test_x2)
+        mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol,
+                            use_broadcast=False, equal_nan=True)
+
+    funcs = {
+        'add': (-1.0, 1.0),
+        'subtract': (-1.0, 1.0),
+        'multiply': (-1.0, 1.0),
+    }
+
+    shape_pairs = [((3, 2), (3, 2)),
+                   ((3, 2), (3, 1)),
+                   ((3, 1), (3, 0)),
+                   ((0, 2), (1, 2)),
+                   ((2, 3, 4), (3, 1)),
+                   ((2, 3), ()),
+                   ((), (2, 3))]
+
+    itypes = [np.bool, np.int8, np.int32, np.int64]
+    ftypes = [np.float16, np.float32, np.float64]
+    for func, func_data in funcs.items():
+        low, high = func_data
+        for lshape, rshape in shape_pairs:
+            for type1, type2 in itertools.product(itypes, ftypes):
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1)
+
+            for type1, type2 in itertools.product(ftypes, ftypes):
+                if type1 == type2:
+                    continue
+                check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2)
+
+
+@with_seed()
+@use_np
+def test_np_boolean_binary_funcs():
+    def check_boolean_binary_func(func, mx_x1, mx_x2):
+        class TestBooleanBinary(HybridBlock):
+            def __init__(self, func):
+                super(TestBooleanBinary, self).__init__()
+                self._func = func
+
+            def hybrid_forward(self, F, a, b, *args, **kwargs):
+                return getattr(F.np, self._func)(a, b)
+
+        np_x1 = mx_x1.asnumpy()
+        np_x2 = mx_x2.asnumpy()
+        np_func = getattr(_np, func)
+        mx_func = TestBooleanBinary(func)
+        for hybridize in [True, False]:
+            if hybridize:
+                mx_func.hybridize()
+            np_out = np_func(np_x1, np_x2)
+            with mx.autograd.record():
+                y = mx_func(mx_x1, mx_x2)
+            assert y.shape == np_out.shape
+            assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=1e-3, atol=1e-20,
+                                use_broadcast=False, equal_nan=True)
+
+        np_out = getattr(_np, func)(np_x1, np_x2)
+        mx_out = getattr(mx.np, func)(mx_x1, mx_x2)
+        assert mx_out.shape == np_out.shape
+        assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=1e-3, atol=1e-20,
+                            use_broadcast=False, equal_nan=True)
+
+
+    funcs = [
+        'add',
+        'multiply',
+        'true_divide',
+    ]
+
+    shape_pairs = [((3, 2), (3, 2)),
+                   ((3, 2), (3, 1)),
+                   ((3, 1), (3, 0)),
+                   ((0, 2), (1, 2)),
+                   ((2, 3, 4), (3, 1)),
+                   ((2, 3), ()),
+                   ((), (2, 3))]
+
+    for lshape, rshape in shape_pairs:
+        for func in funcs:
+            x1 = np.array(_np.random.uniform(size=lshape) > 0.5)
+            x2 = np.array(_np.random.uniform(size=rshape) > 0.5)
+            check_boolean_binary_func(func, x1, x2)
+
+
+@with_seed()
+@use_np
 def test_npx_relu():
     def np_relu(x):
         return _np.maximum(x, 0.0)

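For readers skimming the test changes above: the np.mean hunk verifies the numpy-compatible output dtype for integer inputs. The snippet below is a minimal standalone sketch of that behaviour, not code from the patch; it assumes a build from this branch with numpy semantics enabled via npx.set_np(), and the sample array is arbitrary.

    # Sketch only: with numpy semantics enabled, mean() over an integer array
    # is expected to produce a floating-point result, mirroring _np.mean,
    # rather than keeping the integer dtype.
    import numpy as _np
    from mxnet import np, npx
    npx.set_np()

    x = np.random.uniform(-128, 127, size=(2, 3)).astype('int32')
    print(np.mean(x).dtype)             # expected: a float dtype
    print(_np.mean(x.asnumpy()).dtype)  # NumPy reference: float64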
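Likewise, test_np_mixed_precision_binary_funcs covers elementwise add, subtract and multiply with mixed integer/float operands, and test_np_boolean_binary_funcs covers add, multiply and true_divide on boolean operands. A minimal sketch of those operand combinations follows; it is not part of the commit, the exact result dtypes are stated only loosely, and boolean ndarrays additionally require a build whose boolean kernels are runnable (the tests gate on is_op_runnable()).

    # Sketch only: mixed integer/float and pure boolean elementwise ops under
    # numpy semantics; boolean ndarrays are built from NumPy bool arrays, as in
    # the tests.
    import numpy as _np
    from mxnet import np, npx
    npx.set_np()

    a = np.random.uniform(-100, 100, size=(3,)).astype('int32')
    b = np.random.uniform(-1.0, 1.0, size=(3,)).astype('float32')
    c = a * b                      # mixed int32/float32 multiply; a float result is expected
    print(c.dtype, c.asnumpy())

    p = np.array(_np.random.uniform(size=(3,)) > 0.5)   # boolean ndarray, as in the tests
    q = np.array(_np.random.uniform(size=(3,)) > 0.5)
    print((p * q).asnumpy())       # boolean multiply follows _np (elementwise AND)
    print((p + q).asnumpy())       # boolean add follows _np (elementwise OR)
    r = np.array(_np.ones((3,), dtype=bool))
    print((p / r).asnumpy())       # boolean true_divide promotes to a float result, as in _np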