This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.6.x by this push:
     new 9abb151  Backport to 1.6 (#16773, #16781, #16783, #16716, #16699, #16728, #16769, #16792) (#16832)

9abb151 is described below

commit 9abb1513982aa0e43cd4bda0568651fa28ec0938
Author: Przemyslaw Tredak <ptre...@nvidia.com>
AuthorDate: Fri Nov 15 17:50:38 2019 -0800

    Backport to 1.6 (#16773, #16781, #16783, #16716, #16699, #16728, #16769, #16792) (#16832)

    * Fix nightly build (#16773)
    * Remove dependency on tvmop.conf
    * Fix binaries dependencies for ni nightly
    * Add comments
    * Update tvmop.py
    * Fix rebase
    * Fix (#16781)
    * Speed up fused_op compilation by caching ptx and jit-compiled functions (#16783)
    * [Numpy] Fix collect_params().zero_grad() in gluon numpy interface (#16716)
    * fix zero_grad
    * Update parameter.py
    * add test
    * fix
    * Mixed data type binary ops (#16699)
    * support mixed-precision binary operations
    * improvements to documentation and error messages
    * Support boolean elemwise/broadcast binary add, multiply and true_divide (#16728)
    * support pure boolean elemwise/broadcast binary op
    * switch to unique_ptr
    * fix the test error
    * Fix rtrue_divide grad (#16769)
    * Fix rtrue_divide_scalar
    * More tests
    * Fix numpy-compatible mean output type for integer inputs (#16792)
    * fix mean output type for integer inputs
    * enable for windows
---
 python/mxnet/gluon/parameter.py                    |  12 +-
 python/mxnet/ndarray/numpy/_op.py                  |  40 ++
 python/mxnet/numpy/multiarray.py                   |  40 ++
 python/mxnet/tvmop.py                              |  14 +-
 src/common/utils.h                                 |  30 +-
 src/executor/pointwise_fusion_pass.cc              |  27 +-
 src/ndarray/ndarray_function-inl.h                 |   2 +-
 src/operator/fusion/fused_op-inl.h                 |   4 +-
 src/operator/fusion/fused_op.cc                    |  50 ++-
 src/operator/fusion/fused_op.cu                    | 225 +++++++-----
 src/operator/fusion/fused_op.h                     |  43 ++-
 src/operator/mshadow_op.h                          | 114 +++++-
 src/operator/mxnet_op.h                            |  14 +-
 src/operator/numpy/np_broadcast_reduce_op.h        |  12 +-
 src/operator/numpy/np_broadcast_reduce_op_value.cc |  11 +-
 src/operator/numpy/np_elemwise_broadcast_op.cc     |  97 ++++-
 src/operator/numpy/np_elemwise_broadcast_op.cu     |  37 +-
 src/operator/numpy/np_elemwise_broadcast_op.h      | 404 +++++++++++++++++++++
 src/operator/numpy/np_true_divide.cc               |   4 +-
 src/operator/operator_tune-inl.h                   |   9 +-
 src/operator/operator_tune.cc                      |  10 +-
 src/operator/tensor/broadcast_reduce-inl.cuh       |  13 -
 src/operator/tensor/broadcast_reduce-inl.h         |  15 +-
 src/operator/tensor/broadcast_reduce_op.h          |   6 +-
 src/operator/tensor/elemwise_binary_broadcast_op.h |  32 ++
 src/operator/tensor/elemwise_binary_op.h           |  30 ++
 tests/nightly/JenkinsfileForBinaries               |   6 +-
 .../JenkinsfileForMBCC                             |   2 +-
 tests/python/gpu/test_fusion.py                    |  25 +-
 tests/python/unittest/test_numpy_gluon.py          |  17 +
 tests/python/unittest/test_numpy_ndarray.py        | 107 +++++-
 tests/python/unittest/test_numpy_op.py             | 229 ++++++++++--
 32 files changed, 1390 insertions(+), 291 deletions(-)

diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py index 957dc2c..067a357 100644 --- a/python/mxnet/gluon/parameter.py +++ b/python/mxnet/gluon/parameter.py @@ -27,7 +27,6 @@ __all__ = ['DeferredInitializationError', 'Parameter', 'Constant', from collections import OrderedDict, defaultdict import warnings import numpy as np -import mxnet as mx from ..base import mx_real_t, MXNetError from ..
import symbol, ndarray, initializer, context @@ -896,15 +895,20 @@ class ParameterDict(object): continue for g in p.list_grad(): if g.stype == 'row_sparse': - mx.ndarray.zeros_like(g, out=g) + ndarray.zeros_like(g, out=g) else: arrays[g.context].append(g) if len(arrays) == 0: return - for arr in arrays.values(): - mx.nd.reset_arrays(*arr, num_arrays=len(arr)) + if is_np_array(): + for arr in arrays.values(): + for ele in arr: + ele[()] = 0 + else: + for arr in arrays.values(): + ndarray.reset_arrays(*arr, num_arrays=len(arr)) def reset_ctx(self, ctx): """Re-assign all Parameters to other contexts. diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index c215159..ff404a7 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -522,6 +522,14 @@ def add(x1, x2, out=None, **kwargs): ------- add : ndarray or scalar The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. """ return _ufunc_helper(x1, x2, _npi.add, _np.add, _npi.add_scalar, None, out) @@ -548,6 +556,14 @@ def subtract(x1, x2, out=None, **kwargs): ------- subtract : ndarray or scalar The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. """ return _ufunc_helper(x1, x2, _npi.subtract, _np.subtract, _npi.subtract_scalar, _npi.rsubtract_scalar, out) @@ -575,6 +591,14 @@ def multiply(x1, x2, out=None, **kwargs): out : ndarray or scalar The multiplication of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. """ return _ufunc_helper(x1, x2, _npi.multiply, _np.multiply, _npi.multiply_scalar, None, out) @@ -602,6 +626,14 @@ def divide(x1, x2, out=None, **kwargs): ------- out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), the output is of float32 type. 
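A quick illustration of the promotion rules documented in the Notes blocks above, as seen from the mxnet.numpy frontend. This is a minimal sketch assuming an MXNet 1.6 build with these backports applied; the dtypes in the comments are the expected results under the rules above, not captured output.

    from mxnet import np, npx
    npx.set_np()  # enable NumPy-compatible semantics

    a16 = np.ones((2, 2), dtype='float16')
    b32 = np.ones((2, 2), dtype='float32')
    i64 = np.ones((2, 2), dtype='int64')

    print((a16 + b32).dtype)  # float32: both inputs floating, the more precise type wins
    print((i64 * b32).dtype)  # float32: one floating input, its type is kept
    print((i64 / i64).dtype)  # float32: integer inputs to divide produce float32
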
""" return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar, _npi.rtrue_divide_scalar, out) @@ -632,6 +664,14 @@ def true_divide(x1, x2, out=None): ------- out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. + + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), the output is of float32 type. """ return _ufunc_helper(x1, x2, _npi.true_divide, _np.divide, _npi.true_divide_scalar, _npi.rtrue_divide_scalar, out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 85bd2ac..c623f67 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -2399,6 +2399,14 @@ def add(x1, x2, out=None, **kwargs): add : ndarray or scalar The sum of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. + Examples -------- >>> np.add(1.0, 4.0) @@ -2437,6 +2445,14 @@ def subtract(x1, x2, out=None, **kwargs): subtract : ndarray or scalar The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. + Examples -------- >>> np.subtract(1.0, 4.0) @@ -2473,6 +2489,14 @@ def multiply(x1, x2, out=None, **kwargs): out : ndarray or scalar The difference of x1 and x2, element-wise. This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), not supported yet. + Examples -------- >>> np.multiply(2.0, 4.0) @@ -2511,6 +2535,14 @@ def divide(x1, x2, out=None, **kwargs): out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. + Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), the output is of float32 type. + Examples -------- >>> np.true_divide(x, 4) @@ -2545,6 +2577,14 @@ def true_divide(x1, x2, out=None): out : ndarray or scalar This is a scalar if both x1 and x2 are scalars. 
+ Notes + ----- + This operator now supports automatic type promotion. The resulting type will be determined + according to the following rules: + * If both inputs are of floating number types, the output is the more precise type. + * If only one of the inputs is floating number type, the result is that type. + * If both inputs are of integer types (including boolean), the output is of float32 type. + Examples -------- >>> x = np.arange(5) diff --git a/python/mxnet/tvmop.py b/python/mxnet/tvmop.py index 9ec278a..29e6f0d 100644 --- a/python/mxnet/tvmop.py +++ b/python/mxnet/tvmop.py @@ -21,6 +21,7 @@ from .runtime import Features if Features().is_enabled("TVM_OP"): import json + import logging from ._ctypes.space import _set_tvm_op_config from .base import check_call, _LIB, c_str @@ -31,7 +32,12 @@ if Features().is_enabled("TVM_OP"): check_call(_LIB.MXLoadTVMOp(c_str(_LIB_TVM_OP[0]))) # op sch config - _CONF_TVM_OP = find_conf_path("tvmop") - with open(_CONF_TVM_OP[0], "r") as f: - ret = ConfigSpaces.from_json_dict(json.load(f)) - _set_tvm_op_config(ret) + try: + _CONF_TVM_OP = find_conf_path("tvmop") + except RuntimeError as e: + logging.warning("TVM config file missing, falling back to default schedule", exc_info=True) + else: + logging.info("TVM op config has been loaded") + with open(_CONF_TVM_OP[0], "r") as f: + ret = ConfigSpaces.from_json_dict(json.load(f)) + _set_tvm_op_config(ret) diff --git a/src/common/utils.h b/src/common/utils.h index b919cb3..0e3e354 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -365,6 +365,30 @@ inline bool ContainsStorageType(const std::vector<int>& ndstypes, return false; } +inline std::string dtype_string(const int dtype) { + switch (dtype) { + case mshadow::kFloat32: + return "float"; + case mshadow::kFloat64: + return "double"; + case mshadow::kFloat16: + return "half"; + case mshadow::kUint8: + return "unsigned char"; + case mshadow::kInt8: + return "char"; + case mshadow::kInt32: + return "int"; + case mshadow::kInt64: + return "long long"; + case mshadow::kBool: + return "bool"; + default: + LOG(FATAL) << "Unknown type enum " << dtype; + } + return "unknown"; +} + /*! 
\brief get string representation of dispatch_mode */ inline std::string dispatch_mode_string(const DispatchMode x) { switch (x) { @@ -842,7 +866,7 @@ inline bool is_float(const int dtype) { return dtype == mshadow::kFloat32 || dtype == mshadow::kFloat64 || dtype == mshadow::kFloat16; } -inline int more_precise_type(const int type1, const int type2) { +inline int get_more_precise_type(const int type1, const int type2) { if (type1 == type2) return type1; if (is_float(type1) && is_float(type2)) { if (type1 == mshadow::kFloat64 || type2 == mshadow::kFloat64) { @@ -870,12 +894,12 @@ inline int more_precise_type(const int type1, const int type2) { return mshadow::kInt8; } -inline int np_binary_out_type(const int type1, const int type2) { +inline int np_binary_out_infer_type(const int type1, const int type2) { if ((type1 == mshadow::kUint8 && type2 == mshadow::kInt8) || (type1 == mshadow::kInt8 && type2 == mshadow::kUint8)) { return mshadow::kInt32; } - return more_precise_type(type1, type2); + return get_more_precise_type(type1, type2); } } // namespace common diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc index c6e2405..6fe2140 100644 --- a/src/executor/pointwise_fusion_pass.cc +++ b/src/executor/pointwise_fusion_pass.cc @@ -83,15 +83,7 @@ namespace { auto node = nnvm::Node::Create(); subgraph_sym.outputs = subgraph.outputs; node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym)); - std::ostringstream name_oss; - // the name of the new node will be the concatenation of all the node names in the subgraph - DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) { - if (n->op() != nullptr) - name_oss << n->op()->name << "_"; - }); - auto subgraph_name = name_oss.str(); - subgraph_name.pop_back(); - node->attrs.name = subgraph_name; + node->attrs.name = "FusedOp"; node->attrs.dict["num_inputs"] = std::to_string(inputs_size); node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size()); node->attrs.op = Op::Get("_FusedOp"); @@ -152,7 +144,8 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& sub auto it = node->control_deps.begin(); static auto& is_fusion = Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper"); std::vector<nnvm::NodePtr> new_control_deps; - while (it != node->control_deps.end()) { + // Use the first control dependency to get the inferattr helper + if (it != node->control_deps.end()) { if (subgraph_set.count(it->get())) { new_control_deps.push_back(*it); } else { @@ -160,8 +153,7 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& sub uint32_t node_id = subgraph_node->control_deps.size(); subgraph_node->control_deps.push_back(*it); auto helper_node = op::MakeNode("_FusedOpOutHelper", - subgraph_node->attrs.name + "_" - + node->attrs.name + "_outhelper", + "FusedOp_" + node->attrs.name + "_outhelper", nullptr, nullptr, nullptr); @@ -180,6 +172,17 @@ Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& sub } }); + std::ostringstream name_oss; + // the name of the new node will be the concatenation of all the node names in the subgraph + DFSVisit(subgraph.outputs, [&name_oss](const nnvm::NodePtr n) { + if (n->op() != nullptr) { + name_oss << n->op()->name << "_"; + } + }); + auto subgraph_name = name_oss.str(); + subgraph_name.pop_back(); + subgraph_node->attrs.name = subgraph_name; + const auto& index = subgraph.indexed_graph(); DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const nnvm::NodePtr& node) { for 
(auto &e : node->control_deps) { diff --git a/src/ndarray/ndarray_function-inl.h b/src/ndarray/ndarray_function-inl.h index d494f08..6ac586a 100644 --- a/src/ndarray/ndarray_function-inl.h +++ b/src/ndarray/ndarray_function-inl.h @@ -379,7 +379,7 @@ void EvalRandom<DEVICE, GenNegBinomialDistribution>( template<> void Eval<DEVICE>(const real_t &rhs, TBlob *ret, RunContext ctx) { mshadow::Stream<DEVICE> *s = ctx.get_stream<DEVICE>(); - MSHADOW_TYPE_SWITCH(ret->type_flag_, DType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(ret->type_flag_, DType, { ret->FlatTo2D<DEVICE, DType>(s) = DType(rhs); }); } diff --git a/src/operator/fusion/fused_op-inl.h b/src/operator/fusion/fused_op-inl.h index 3085bfd..2966fe2 100644 --- a/src/operator/fusion/fused_op-inl.h +++ b/src/operator/fusion/fused_op-inl.h @@ -982,11 +982,9 @@ const char kernel_begin[] = R"code( const int tid = threadIdx.x + blockIdx.x * blockDim.x; for (int i = tid; i < N; i+= gridDim.x * blockDim.x) { int offset = i*nvec; - )code"; -const char kernel_end[] = R"code( -} +const char kernel_end[] = R"code(} } )code"; diff --git a/src/operator/fusion/fused_op.cc b/src/operator/fusion/fused_op.cc index 071215b..5c83c30 100644 --- a/src/operator/fusion/fused_op.cc +++ b/src/operator/fusion/fused_op.cc @@ -49,31 +49,30 @@ void FusedOpParamParser(nnvm::NodeAttrs* attrs) { attrs->parsed = FusedOpPtr(new FusedOp(attrs, param)); } -FusedOp::FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config) { - this->inputs_ = std::vector<FusedOpEntry>(config.num_inputs); - this->outputs_ = std::vector<FusedOpEntry>(config.num_outputs); - this->subgraph_ = nnvm::Graph(); - this->subgraph_.outputs = attrs->subgraphs[0]->outputs; - this->initialized_ = false; - this->cc_major_ = -1; - this->cc_minor_ = -1; +FusedOp::FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config) : + initialized_(false), + kernel_function_dev_id_(-1) { + inputs_ = std::vector<FusedOpEntry>(config.num_inputs); + outputs_ = std::vector<FusedOpEntry>(config.num_outputs); + subgraph_ = nnvm::Graph(); + subgraph_.outputs = attrs->subgraphs[0]->outputs; } bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs, std::vector<mxnet::TShape> *in_attrs, std::vector<mxnet::TShape> *out_attrs) { - this->subgraph_.attrs.erase("shape"); - this->subgraph_.attrs.erase("shape_inputs"); + subgraph_.attrs.erase("shape"); + subgraph_.attrs.erase("shape_inputs"); std::vector<mxnet::TShape> input_shapes(*in_attrs); - this->subgraph_ = mxnet::exec::InferShape(std::move(this->subgraph_), - std::move(input_shapes), - "__shape__"); + subgraph_ = mxnet::exec::InferShape(std::move(subgraph_), + std::move(input_shapes), + "__shape__"); - const auto& g = this->subgraph_.indexed_graph(); + const auto& g = subgraph_.indexed_graph(); const auto& input_nids = g.input_nodes(); std::vector<mxnet::TShape> out_shapes; - const std::vector<mxnet::TShape> shapes = this->subgraph_.GetAttr<mxnet::ShapeVector>("shape"); + const std::vector<mxnet::TShape> shapes = subgraph_.GetAttr<mxnet::ShapeVector>("shape"); for (auto& e : g.outputs()) { out_shapes.push_back(shapes[g.entry_id(e)]); } @@ -105,18 +104,18 @@ bool FusedOp::InferShape(const nnvm::NodeAttrs &attrs, bool FusedOp::InferType(const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs, std::vector<int> *out_attrs) { - this->subgraph_.attrs.erase("dtype"); - this->subgraph_.attrs.erase("dtype_inputs"); + subgraph_.attrs.erase("dtype"); + subgraph_.attrs.erase("dtype_inputs"); std::vector<int> input_types(*in_attrs); - this->subgraph_ = 
mxnet::exec::InferType(std::move(this->subgraph_), - std::move(input_types), - "__dtype__"); + subgraph_ = mxnet::exec::InferType(std::move(subgraph_), + std::move(input_types), + "__dtype__"); - const auto& g = this->subgraph_.indexed_graph(); + const auto& g = subgraph_.indexed_graph(); const auto& input_nids = g.input_nodes(); std::vector<int> out_types; - const std::vector<int> types = this->subgraph_.GetAttr<nnvm::DTypeVector>("dtype"); + const std::vector<int> types = subgraph_.GetAttr<nnvm::DTypeVector>("dtype"); for (auto& e : g.outputs()) { out_types.push_back(types[g.entry_id(e)]); } @@ -149,10 +148,9 @@ template <typename Attr> std::tuple<const nnvm::NodePtr, std::vector<Attr>, std::vector<Attr>> - FusedOp::GetAttrs(const std::string& attr_name, - const uint32_t node_id) { - const auto& g = this->subgraph_.indexed_graph(); - const std::vector<Attr> attrs = this->subgraph_.GetAttr<std::vector<Attr>>(attr_name); +FusedOp::GetAttrs(const std::string& attr_name, const uint32_t node_id) { + const auto& g = subgraph_.indexed_graph(); + const std::vector<Attr> attrs = subgraph_.GetAttr<std::vector<Attr>>(attr_name); const auto& node = g[node_id]; std::vector<Attr> inputs, outputs; for (const auto& e : node.inputs) { diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu index f6df38b..78988f1 100644 --- a/src/operator/fusion/fused_op.cu +++ b/src/operator/fusion/fused_op.cu @@ -163,9 +163,33 @@ void AddPointerAndShape(const TBlob& data, }); } +// Obtain compilation log from the program. +std::string GetCompileLog(nvrtcProgram program) { + size_t log_size_including_null; + NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size_including_null)); + // For most std::string implementations, this is probably 1 char bigger than needed. OK though. + std::string log(log_size_including_null, '\0'); + NVRTC_CALL(nvrtcGetProgramLog(program, &log[0])); + // Make sure the string reflects the true size (so minus the null terminator). + log.resize(log_size_including_null - 1); + return log; +} + +// Obtain compilation result (ptx assembly) from the program. +std::string GetPtx(nvrtcProgram program) { + size_t ptx_size_including_null; + NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size_including_null)); + // For most std::string implementations, this is probably 1 char bigger than needed. OK though. + std::string ptx(ptx_size_including_null, '\0'); + NVRTC_CALL(nvrtcGetPTX(program, &ptx[0])); + // Make sure the string reflects the true size (so minus the null terminator). 
+ ptx.resize(ptx_size_including_null - 1); + return ptx; +} + } // namespace -void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req, +std::string FusedOp::GenerateCode(const std::vector<OpReqType> &req, const std::vector<int> &in_dtypes, const std::vector<int> &out_dtypes, const std::vector<int> &in_ndims, @@ -175,7 +199,7 @@ void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req, const int nvec, const std::string &kernel_name, std::vector<uint32_t>* check_shapes) { - const auto& g = this->subgraph_.indexed_graph(); + const auto& g = subgraph_.indexed_graph(); std::string code = ""; int temp_name_counter = 0; using NodeEntry = nnvm::IndexedGraph::NodeEntry; @@ -459,16 +483,11 @@ void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req, ++counter; } - this->code_[kernel_index] = code; - // Add boilerplate and type information - if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) { - LOG(INFO) << code_[kernel_index]; - } std::string kernel_params = ""; std::string tensor_params = ""; nnvm::Symbol sym; - sym.outputs = this->subgraph_.outputs; + sym.outputs = subgraph_.outputs; const std::vector<std::string> input_names = sym.ListInputNames(nnvm::Symbol::kAll); size_t num_params = in_dtypes.size() + out_dtypes.size(); size_t i = 0; @@ -513,85 +532,102 @@ void FusedOp::GenerateCode(int kernel_index, const std::vector<OpReqType> &req, } kernel_params += tensor_params; - code_[kernel_index] = std::string(fusion::fp16_support_string) + "\n" + - fusion::type_support_string + "\n" + - fusion::function_definitions + "\n" + - fusion::backward_function_definitions + "\n" + - aux_code + "\n" + - "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" + - "__global__ void FusedKernel_" + kernel_name + - "(size_t N, " + kernel_params + ") {\n" + - fusion::kernel_begin + "\n" + - code_[kernel_index] + "\n" + - fusion::kernel_end; + // Create kernel source (minus the common header) + return aux_code + "\n" + + "__launch_bounds__(" + std::to_string(FusedOp::NTHREADS) + ")\n" + + "__global__ void FusedKernel_" + kernel_name + + "(size_t N, " + kernel_params + ") {\n" + + fusion::kernel_begin + "\n" + + code + "\n" + + fusion::kernel_end; } -void FusedOp::CompileCode(int kernel_index, const std::string &kernel_name) { +CUfunction FusedOp::CompileCode(const std::string &code, + const std::string &kernel_name, + int dev_id) { // Guard NVRTC calls std::lock_guard<std::mutex> lock_nvrtc(mutex_); - nvrtcProgram program; - NVRTC_CALL( - nvrtcCreateProgram(&program, // prog - &code_[kernel_index][0], // buffer - (kernel_name + "_kernel.cu").c_str(), // name - 0, // num headers - NULL, // headers - NULL)); // include names - std::string gpu_arch = "--gpu-architecture=compute_" + - std::to_string(this->cc_major_) + - std::to_string(this->cc_minor_); - - const char *opts[] = {gpu_arch.c_str(), - "--std=c++11", - "-default-device"}; - const std::string kernel_name_demangled = "FusedKernel_" + kernel_name; - NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str())); - - nvrtcResult compileResult = nvrtcCompileProgram(program, // prog - 3, // num options - opts); // options - // Obtain compilation log from the program. - size_t log_size; - NVRTC_CALL(nvrtcGetProgramLogSize(program, &log_size)); - std::string log(log_size, '\0'); - NVRTC_CALL(nvrtcGetProgramLog(program, &log[0])); - CHECK_EQ(compileResult, NVRTC_SUCCESS) - << "NVRTC Compilation failed. 
Please set environment variable MXNET_USE_FUSION to 0.\n" << log; - // Obtain PTX from the program. - size_t ptx_size; - NVRTC_CALL(nvrtcGetPTXSize(program, &ptx_size)); - ptx_[kernel_index].reserve(ptx_size); - NVRTC_CALL(nvrtcGetPTX(program, &ptx_[kernel_index][0])); - const char *name; - NVRTC_CALL(nvrtcGetLoweredName(program, - kernel_name_demangled.c_str(), - &name)); - kernel_name_[kernel_index] = name; - // Destroy the program. - NVRTC_CALL(nvrtcDestroyProgram(&program)); - int device; - CUdevice cu_device; - CUcontext context; - CUmodule module; - CUDA_CALL(cudaGetDevice(&device)); - CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, device)); - CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device)); - CUDA_DRIVER_CALL(cuModuleLoadData(&module, &ptx_[kernel_index][0])); - CUDA_DRIVER_CALL(cuModuleGetFunction(&kernel_[kernel_index], - module, - kernel_name_[kernel_index].c_str())); + // Local class for value type of compile cache + struct KernelInfo { + std::string mangled_name; + std::string ptx; + std::vector<CUfunction> functions; + }; + // Maps from the cuda source code (minus header) to the ptx and jit-compiled CUfunctions. + using KernelCache = std::map<std::string, KernelInfo>; + // Per-gpu-architecture compiled kernel cache with jit-compiled function for each device context + static std::map<int32_t, KernelCache> compiled_kernels; + int sm_arch = SMArch(dev_id); + KernelCache& compiled_kernels_this_arch = compiled_kernels[sm_arch]; // make null map as needed + KernelInfo& kinfo = compiled_kernels_this_arch[code]; // make KernelInfo as needed + if (kinfo.ptx.size() == 0) { + // It's the first time we've seen this kernel, so we need to generate the ptx and mangled_name. + static std::string common_header = + std::string(fusion::fp16_support_string) + "\n" + + fusion::type_support_string + "\n" + + fusion::function_definitions + "\n" + + fusion::backward_function_definitions + "\n"; + std::string code_with_header = common_header + code; + // If verbose mode, output kernel source, though not including the common header + if (dmlc::GetEnv("MXNET_FUSION_VERBOSE", false)) { + LOG(INFO) << "\n" << std::string(80, '-') << "\n" << code; + } + if (compiled_kernels_this_arch.size() == CACHESIZE_WARN_THRESHOLD + 1 && + dmlc::GetEnv("MXNET_FUSION_SIZE_WARNING", true)) { + LOG(WARNING) << "The number of different fused ops exceeds " << CACHESIZE_WARN_THRESHOLD + << ". Set MXNET_FUSION_SIZE_WARNING=0 to quiet this warning."; + } + nvrtcProgram program; + NVRTC_CALL(nvrtcCreateProgram(&program, // prog + &code_with_header[0], // buffer + (kernel_name + "_kernel.cu").c_str(), // name + 0, // num headers + NULL, // headers + NULL)); // include names + + std::string gpu_arch_arg = "--gpu-architecture=compute_" + std::to_string(sm_arch); + const char *opts[] = {gpu_arch_arg.c_str(), + "--std=c++11", + "-default-device"}; + const std::string kernel_name_demangled = "FusedKernel_" + kernel_name; + NVRTC_CALL(nvrtcAddNameExpression(program, (kernel_name_demangled).c_str())); + + nvrtcResult compileResult = nvrtcCompileProgram(program, // prog + 3, // num options + opts); // options + CHECK_EQ(compileResult, NVRTC_SUCCESS) + << "NVRTC Compilation failed. Please set environment variable MXNET_USE_FUSION to 0.\n" + << GetCompileLog(program); + + kinfo.ptx = GetPtx(program); + const char *mangled_name; + NVRTC_CALL(nvrtcGetLoweredName(program, + kernel_name_demangled.c_str(), + &mangled_name)); + kinfo.mangled_name = mangled_name; + // Destroy the program. 
+ NVRTC_CALL(nvrtcDestroyProgram(&program)); + } + // Ensure function array is deep enough to index by dev_id + while (kinfo.functions.size() <= static_cast<size_t>(dev_id)) + kinfo.functions.push_back(static_cast<CUfunction>(nullptr)); + // Jit-compile ptx for the device as needed + if (kinfo.functions[dev_id] == static_cast<CUfunction>(nullptr)) { + // Make sure driver context is set to the proper device + CUdevice cu_device; + CUcontext context; + CUDA_DRIVER_CALL(cuDeviceGet(&cu_device, dev_id)); + CUDA_DRIVER_CALL(cuDevicePrimaryCtxRetain(&context, cu_device)); + // Jit-compile ptx for the driver's current context + CUmodule module; + CUDA_DRIVER_CALL(cuModuleLoadData(&module, kinfo.ptx.c_str())); + CUDA_DRIVER_CALL(cuModuleGetFunction(&kinfo.functions[dev_id], + module, + kinfo.mangled_name.c_str())); + } + return kinfo.functions[dev_id]; } -bool FusedOp::CheckComputeCapability(const OpContext &ctx) { - const int dev_id = ctx.run_ctx.ctx.dev_id; - const int cc_major = ComputeCapabilityMajor(dev_id); - const int cc_minor = ComputeCapabilityMinor(dev_id); - - const bool ret = cc_major == this->cc_major_ && cc_minor == this->cc_minor_; - this->cc_major_ = cc_major; - this->cc_minor_ = cc_minor; - return ret; -} void FusedOp::CheckShapesAndTypes(const std::vector<TBlob> &inputs, const std::vector<TBlob> &outputs, @@ -665,23 +701,30 @@ void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs, const auto& node_shapes = intermediate_shapes_[0].internal_attr; const auto& node_dtypes = intermediate_dtypes_[0].internal_attr; - // Check and save compute capability of the current GPU - if (!CheckComputeCapability(ctx)) initialized_ = false; + int dev_id = ctx.run_ctx.ctx.dev_id; + // A change between training and inference modes may require different kernel functions initialized_ = initialized_ && (req == saved_reqs_); saved_reqs_ = req; if (!initialized_) { - this->GenerateCode(0, req, in_dtypes, out_dtypes, in_ndims, out_ndims, + const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, out_ndims, node_shapes, node_dtypes, nvec, attrs.name, &check_shape_args_); - this->CompileCode(0, attrs.name); + kernel_functions_[fusion::kGeneral] = CompileCode(code, attrs.name, dev_id); if (check_shape_args_.size() > 0) { - this->GenerateCode(1, req, in_dtypes, out_dtypes, in_ndims, out_ndims, + const auto& code = GenerateCode(req, in_dtypes, out_dtypes, in_ndims, out_ndims, node_shapes, node_dtypes, nvec, attrs.name, NULL); - this->CompileCode(1, attrs.name); + kernel_functions_[fusion::kShapeOptimized] = CompileCode(code, attrs.name, dev_id); } initialized_ = true; + kernel_function_dev_id_ = dev_id; } + + // A change in device would force recompiling, but this is unexpected so signal as an error + if (dev_id != kernel_function_dev_id_) + LOG(FATAL) << "Fused op compiled for device " << kernel_function_dev_id_ + << ", not expecting switch to device " << dev_id; + Stream<gpu>* s = ctx.get_stream<gpu>(); auto stream = Stream<gpu>::GetStream(s); std::vector<void*> args; @@ -713,18 +756,18 @@ void FusedOp::Forward<gpu>(const nnvm::NodeAttrs& attrs, for (auto &ptr : ptrs) { args.push_back(reinterpret_cast<void *>(&ptr)); } - int kernel_index = 0; + int kernel_variant = fusion::kGeneral; if (check_shape_args_.size() > 0) { - kernel_index = 1; + kernel_variant = fusion::kShapeOptimized; for (const auto &shape_id : check_shape_args_) { const auto& shape = node_shapes[shape_id]; if (shape[shape.ndim()-1] % nvec != 0) { - kernel_index = 0; + kernel_variant = fusion::kGeneral; } } } CUDA_DRIVER_CALL( - 
cuLaunchKernel(kernel_[kernel_index], + cuLaunchKernel(kernel_functions_[kernel_variant], num_blocks, 1, 1, // grid dim FusedOp::NTHREADS, 1, 1, // block dim 0, stream, // shared mem and stream diff --git a/src/operator/fusion/fused_op.h b/src/operator/fusion/fused_op.h index 035e5432..24603ac 100644 --- a/src/operator/fusion/fused_op.h +++ b/src/operator/fusion/fused_op.h @@ -34,6 +34,12 @@ namespace mxnet { +namespace fusion { + enum KernelVariants {kGeneral, kShapeOptimized, + kNumKernelVariants // Not a variant- leave this at the end + }; +} + struct FusedOpConfig : public dmlc::Parameter<FusedOpConfig> { int num_inputs; int num_outputs; @@ -53,6 +59,7 @@ struct FusedOpEntry { class FusedOp { public: static const int NTHREADS = 512; + static const int CACHESIZE_WARN_THRESHOLD = 10000; explicit FusedOp(const nnvm::NodeAttrs* attrs, const FusedOpConfig& config); ~FusedOp() {} @@ -120,20 +127,20 @@ class FusedOp { } private: - void GenerateCode(int kernel_index, - const std::vector<OpReqType> &req, - const std::vector<int> &in_dtypes, - const std::vector<int> &out_dtypes, - const std::vector<int> &in_ndims, - const std::vector<int> &out_ndims, - const mxnet::ShapeVector &node_shapes, - const std::vector<int> &node_dtypes, - const int nvec, - const std::string& kernel_name, - std::vector<uint32_t> *check_shapes); - void CompileCode(int kernel_index, - const std::string &kernel_name); - bool CheckComputeCapability(const OpContext &ctx); + std::string GenerateCode(const std::vector<OpReqType> &req, + const std::vector<int> &in_dtypes, + const std::vector<int> &out_dtypes, + const std::vector<int> &in_ndims, + const std::vector<int> &out_ndims, + const mxnet::ShapeVector &node_shapes, + const std::vector<int> &node_dtypes, + const int nvec, + const std::string& kernel_name, + std::vector<uint32_t> *check_shapes); + + CUfunction CompileCode(const std::string &code, + const std::string &kernel_name, int dev_id); + void CheckShapesAndTypes(const std::vector<TBlob> &inputs, const std::vector<TBlob> &outputs, std::vector<int> *in_dtypes, @@ -145,7 +152,6 @@ class FusedOp { std::vector<FusedOpEntry> inputs_; std::vector<FusedOpEntry> outputs_; - std::string code_[2]; nnvm::Graph subgraph_; template <typename T> @@ -173,12 +179,9 @@ class FusedOp { std::vector<uint32_t> extra_shape_args_; std::vector<uint32_t> check_shape_args_; - std::string ptx_[2]; - std::string kernel_name_[2]; - CUfunction kernel_[2]; + CUfunction kernel_functions_[fusion::kNumKernelVariants]; bool initialized_; - int cc_major_; - int cc_minor_; + int kernel_function_dev_id_; static std::mutex mutex_; std::mutex my_mutex_; diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 1ece97b..d563f25 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -97,6 +97,18 @@ using std::is_integral; } \ } +#define MXNET_BINARY_MATH_OP_NC_WITH_BOOL(name, expr) \ + struct name : public mxnet_op::tunable { \ + template<typename DType, \ + typename std::enable_if<!std::is_same<DType, bool>::value, int>::type = 0> \ + MSHADOW_XINLINE static DType Map(DType a, DType b) { \ + return (expr); \ + } \ + MSHADOW_XINLINE static bool Map(bool a, bool b) { \ + return (expr); \ + } \ + } + #define MXNET_BINARY_LOGIC_OP_NC(name, expr) \ struct name : public mxnet_op::tunable { \ template<typename DType> \ @@ -192,13 +204,107 @@ MXNET_BINARY_MATH_OP_NC(left, a); MXNET_BINARY_MATH_OP_NC(right, b); -MXNET_BINARY_MATH_OP_NC(mul, a * b); +#ifndef _WIN32 +struct mixed_plus { + template<typename DType, + typename 
std::enable_if<std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return static_cast<mshadow::half::half_t>(a) + b; + } + + template<typename DType, + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return static_cast<float>(a) + b; + } + + template<typename DType, + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_same<DType, float>::value || + std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return static_cast<double>(a) + b; + } +}; + +struct mixed_minus { + template<typename DType, + typename std::enable_if<std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return static_cast<mshadow::half::half_t>(a) - b; + } + + template<typename DType, + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return static_cast<float>(a) - b; + } + + template<typename DType, + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_same<DType, float>::value || + std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return static_cast<double>(a) - b; + } +}; + +struct mixed_rminus { + template<typename DType, + typename std::enable_if<std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return b - static_cast<mshadow::half::half_t>(a); + } + + template<typename DType, + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return b - static_cast<float>(a); + } + + template<typename DType, + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_same<DType, float>::value || + std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return b - static_cast<double>(a); + } +}; + +struct mixed_mul { + template<typename DType, + typename std::enable_if<std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static mshadow::half::half_t Map(DType a, mshadow::half::half_t b) { + return static_cast<mshadow::half::half_t>(a) * b; + } + + template<typename DType, + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static float Map(DType a, float b) { + return static_cast<float>(a) * b; + } + + template<typename DType, + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_same<DType, float>::value || + std::is_integral<DType>::value, int>::type = 0> + MSHADOW_XINLINE static double Map(DType a, double b) { + return static_cast<double>(a) * b; + } +}; +#endif + +MXNET_BINARY_MATH_OP_NC_WITH_BOOL(mul, a * b); -MXNET_BINARY_MATH_OP_NC(div, a / b); +MXNET_BINARY_MATH_OP_NC_WITH_BOOL(div, a / b); -MXNET_BINARY_MATH_OP_NC(plus, a + b); +MXNET_BINARY_MATH_OP_NC_WITH_BOOL(plus, a + b); -MXNET_BINARY_MATH_OP_NC(minus, a - b); +MXNET_BINARY_MATH_OP_NC_WITH_BOOL(minus, a - b); 
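The *_WITH_BOOL operator variants registered above are what allow pure-boolean tensors through elemwise/broadcast add, multiply and true_divide (#16728). A small usage sketch from the Python side, under the same npx.set_np() assumption as the earlier example; expected dtypes are noted in comments.

    from mxnet import np, npx
    npx.set_np()

    x = np.array([True, False, True])
    y = np.array([True, True, True])

    print((x * y).dtype)  # bool: pure-boolean multiply stays boolean
    print((x + y).dtype)  # bool: pure-boolean add stays boolean
    print((x / y).dtype)  # float32: true_divide promotes boolean inputs to float32
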
MXNET_UNARY_MATH_OP(negation, -a); diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index 5d297a5..b15117f 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -859,14 +859,17 @@ struct op_with_req { /*! \brief inputs are two tensors with a float output tensor */ template<typename DType, - typename std::enable_if<std::is_integral<DType>::value, int>::type = 0> + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_integral<DType>::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float *rhs) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); } /*! \brief inputs are two tensors with a double output tensor */ template<typename DType, - typename std::enable_if<std::is_integral<DType>::value, int>::type = 0> + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_same<DType, float>::value || + std::is_integral<DType>::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double *rhs) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i])); } @@ -883,14 +886,17 @@ struct op_with_req { /*! \brief inputs are two tensors with a float output tensor */ template<typename DType, - typename std::enable_if<std::is_integral<DType>::value, int>::type = 0> + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_integral<DType>::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t i, float *out, const DType *lhs, const float value) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); } /*! \brief inputs are two tensors with a double output tensor */ template<typename DType, - typename std::enable_if<std::is_integral<DType>::value, int>::type = 0> + typename std::enable_if<std::is_same<DType, mshadow::half::half_t>::value || + std::is_same<DType, float>::value || + std::is_integral<DType>::value, int>::type = 0> MSHADOW_XINLINE static void Map(index_t i, double *out, const DType *lhs, const double value) { KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], value)); } diff --git a/src/operator/numpy/np_broadcast_reduce_op.h b/src/operator/numpy/np_broadcast_reduce_op.h index 4951f62..3566323 100644 --- a/src/operator/numpy/np_broadcast_reduce_op.h +++ b/src/operator/numpy/np_broadcast_reduce_op.h @@ -52,6 +52,7 @@ struct NumpyReduceAxesParam : public dmlc::Parameter<NumpyReduceAxesParam> { .add_enum("int8", mshadow::kInt8) .add_enum("int32", mshadow::kInt32) .add_enum("int64", mshadow::kInt64) + .add_enum("bool", mshadow::kBool) .set_default(dmlc::optional<int>()) .describe("The type of the returned array and of the accumulator in which the elements are " "summed. 
The dtype of a is used by default unless a has an integer dtype of less " @@ -221,15 +222,15 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, const std::vector<TBlob>& inputs, const std::vector<OpReqType>& req, const std::vector<TBlob>& outputs) { + using namespace mshadow; if (req[0] == kNullOp) return; const NumpyReduceAxesParam& param = nnvm::get<NumpyReduceAxesParam>(attrs.parsed); if (param.initial.has_value()) { LOG(FATAL) << "initial is not supported yet"; } + Stream<xpu>* s = ctx.get_stream<xpu>(); if (inputs[0].shape_.Size() == 0 && outputs[0].shape_.Size() != 0) { using namespace mxnet_op; - using namespace mshadow; - Stream<xpu>* s = ctx.get_stream<xpu>(); MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { Kernel<set_zero, xpu>::Launch(s, outputs[0].shape_.Size(), outputs[0].dptr<DType>()); }); @@ -246,6 +247,13 @@ void NumpyReduceAxesCompute(const nnvm::NodeAttrs& attrs, LOG(FATAL) << "Only reduce op: `sum` is supported for boolean ndarrays"; } TVMOpReduce(ctx, inputs[0], param.axis, outputs[0], req[0], reducer_name); + if (normalize) { + using namespace mshadow::expr; + MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + auto out = outputs[0].FlatTo2D<xpu, OType>(s); + out /= scalar<OType>(inputs[0].Size()/outputs[0].Size()); + }); + } return; } #endif diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index 435fe1d..fb13356 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -257,13 +257,14 @@ inline bool NumpyMeanType(const nnvm::NodeAttrs& attrs, const NumpyReduceAxesParam ¶m = nnvm::get<NumpyReduceAxesParam>(attrs.parsed); if (param.dtype.has_value()) { - if (IsIntType(in_attrs->at(0)) && !IsIntType(param.dtype.value())) { - LOG(FATAL) << "Output cannot be float type when input is integer type for now"; - } TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value()); } else { - TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); - TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + if (common::is_float(in_attrs->at(0))) { + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + } else { + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32); + } } return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index c206ad4..a76e59d 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -23,8 +23,7 @@ * \brief CPU Implementation of basic functions for elementwise numpy binary broadcast operator. 
*/ -#include "../tensor/elemwise_binary_broadcast_op.h" -#include "../tensor/elemwise_binary_scalar_op.h" +#include "./np_elemwise_broadcast_op.h" namespace mxnet { namespace op { @@ -55,17 +54,99 @@ bool NumpyBinaryScalarType(const nnvm::NodeAttrs& attrs, .add_argument("data", "NDArray-or-Symbol", "source input") \ .add_argument("scalar", "float", "scalar input") +bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs, + std::vector<int>* in_attrs, + std::vector<int>* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + const int ltype = in_attrs->at(0); + const int rtype = in_attrs->at(1); + if (ltype != -1 && rtype != -1 && (ltype != rtype)) { + // Only when both input types are known and not the same, we enter the mixed-precision mode + TYPE_ASSIGN_CHECK(*out_attrs, 0, common::np_binary_out_infer_type(ltype, rtype)); + } else { + return ElemwiseType<2, 1>(attrs, in_attrs, out_attrs); + } + return true; +} -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_add) -.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::plus>) +#ifndef _WIN32 +#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr<nnvm::FListInputNames>("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector<std::string>{"lhs", "rhs"}; \ + }) \ + .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \ + .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType) \ + .set_attr<nnvm::FInplaceOption>("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \ + }) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") +#else +#define MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(name) \ + NNVM_REGISTER_OP(name) \ + .set_num_inputs(2) \ + .set_num_outputs(1) \ + .set_attr<nnvm::FListInputNames>("FListInputNames", \ + [](const NodeAttrs& attrs) { \ + return std::vector<std::string>{"lhs", "rhs"}; \ + }) \ + .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape) \ + .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType) \ + .set_attr<nnvm::FInplaceOption>("FInplaceOption", \ + [](const NodeAttrs& attrs){ \ + return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}}; \ + }) \ + .set_attr<FResourceRequest>("FResourceRequest", \ + [](const NodeAttrs& attrs) { \ + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; \ + }) \ + .add_argument("lhs", "NDArray-or-Symbol", "First input to the function") \ + .add_argument("rhs", "NDArray-or-Symbol", "Second input to the function") +#endif + +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_add) +#ifndef _WIN32 +.set_attr<FCompute>( + "FCompute<cpu>", + NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus, + op::mshadow_op::mixed_plus>) +#else +.set_attr<FCompute>( + "FCompute<cpu>", + NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::plus>) +#endif .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_add"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_subtract) -.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::minus>) +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_subtract) +#ifndef _WIN32 +.set_attr<FCompute>( + "FCompute<cpu>", + NumpyBinaryBroadcastCompute<cpu, 
op::mshadow_op::minus, op::mshadow_op::mixed_minus, + op::mshadow_op::mixed_rminus>) +#else +.set_attr<FCompute>( + "FCompute<cpu>", + NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::minus>) +#endif .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_broadcast_sub"}); -MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_multiply) -.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, op::mshadow_op::mul>) +MXNET_OPERATOR_REGISTER_NP_BINARY_MIXED_PRECISION(_npi_multiply) +#ifndef _WIN32 +.set_attr<FCompute>( + "FCompute<cpu>", + NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul, + op::mshadow_op::mixed_mul>) +#else +.set_attr<FCompute>( + "FCompute<cpu>", + NumpyBinaryBroadcastComputeWithBool<cpu, op::mshadow_op::mul>) +#endif .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_mul"}); MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_mod) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index a682ec9..a0a277d 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -22,20 +22,47 @@ * \file np_elemwise_broadcast_op.cu * \brief GPU Implementation of basic functions for elementwise binary broadcast operator. */ -#include "../tensor/elemwise_binary_broadcast_op.h" -#include "../tensor/elemwise_binary_scalar_op.h" + +#include "./np_elemwise_broadcast_op.h" namespace mxnet { namespace op { NNVM_REGISTER_OP(_npi_add) -.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::plus>); +#ifndef _WIN32 +.set_attr<FCompute>( + "FCompute<gpu>", + NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus, op::mshadow_op::mixed_plus, + op::mshadow_op::mixed_plus>); +#else +.set_attr<FCompute>( + "FCompute<gpu>", + NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::plus>); +#endif NNVM_REGISTER_OP(_npi_subtract) -.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::minus>); +#ifndef _WIN32 +.set_attr<FCompute>( + "FCompute<gpu>", + NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus, op::mshadow_op::mixed_minus, + op::mshadow_op::mixed_rminus>); +#else +.set_attr<FCompute>( + "FCompute<gpu>", + NumpyBinaryBroadcastCompute<gpu, op::mshadow_op::minus>); +#endif NNVM_REGISTER_OP(_npi_multiply) -.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, op::mshadow_op::mul>); +#ifndef _WIN32 +.set_attr<FCompute>( + "FCompute<gpu>", + NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul, op::mshadow_op::mixed_mul, + op::mshadow_op::mixed_mul>); +#else +.set_attr<FCompute>( + "FCompute<gpu>", + NumpyBinaryBroadcastComputeWithBool<gpu, op::mshadow_op::mul>); +#endif NNVM_REGISTER_OP(_npi_mod) .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::mod>); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h new file mode 100644 index 0000000..1a4596f --- /dev/null +++ b/src/operator/numpy/np_elemwise_broadcast_op.h @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_elemwise_binary_op.h + * \brief Function definition of elemwise and broadcast operators + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ +#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_ + +#include <vector> +#include <string> + +#include "../tensor/elemwise_binary_broadcast_op.h" +#include "../tensor/elemwise_binary_scalar_op.h" + +namespace mxnet { +namespace op { + +inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) { + LOG(FATAL) << "Operator " << op_name << " does not support combination of " + << common::dtype_string(dtype1) << " with " << common::dtype_string(dtype2) + << " yet..."; +} + +#ifndef _WIN32 +template<typename xpu, typename OP> +void MixedAllRealBinaryElemwiseCompute(const std::string& op_name, + const OpContext& ctx, + const TBlob& lhs, + const TBlob& rhs, + const TBlob& out, + const OpReqType req) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(lhs.type_flag_, out.type_flag_); + + Stream<xpu> *s = ctx.get_stream<xpu>(); + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, { + const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size()) + + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes; + if (size == 0) return; + + switch (lhs.type_flag_) { + case mshadow::kFloat32: + { + if (rhs.type_flag_ == mshadow::kFloat16) { + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch( + s, size, out.dptr<float>(), rhs.dptr<mshadow::half::half_t>(), + lhs.dptr<float>()); + }); + } else { + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); + } + break; + } + case mshadow::kFloat64: + { + if (rhs.type_flag_ == mshadow::kFloat16) { + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch( + s, size, out.dptr<double>(), rhs.dptr<mshadow::half::half_t>(), + lhs.dptr<double>()); + }); + } else if (rhs.type_flag_ == mshadow::kFloat32) { + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch( + s, size, out.dptr<double>(), rhs.dptr<float>(), + lhs.dptr<double>()); + }); + } else { + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); + } + break; + } + default: + { + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); + break; + } + } + }); +} + +template<typename xpu, typename OP> +void MixedIntRealBinaryElemwiseCompute(const OpContext& ctx, + const TBlob& lhs, + const TBlob& rhs, + const TBlob& out, + const OpReqType req) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(lhs.type_flag_, out.type_flag_); + + Stream<xpu> *s = ctx.get_stream<xpu>(); + + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, FType, { + const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size()) + + DataType<FType>::kLanes - 1) / DataType<FType>::kLanes; + if (size == 0) return; + + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, IType, { + MXNET_ASSIGN_REQ_SWITCH(req, Req, { + Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch( + s, size, out.dptr<FType>(), 
rhs.dptr<IType>(), + lhs.dptr<FType>()); + }); + }); + }); +} + +template<typename xpu, typename LOP, typename ROP> +void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + if (lhs.type_flag_ == out.type_flag_) { + MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, rhs, out, req[0]); + } else { + MixedAllRealBinaryElemwiseCompute<xpu, LOP>(attrs.op->name, ctx, rhs, lhs, out, req[0]); + } + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + if (lhs.type_flag_ == out.type_flag_) { + MixedIntRealBinaryElemwiseCompute<xpu, ROP>(ctx, lhs, rhs, out, req[0]); + } else { + MixedIntRealBinaryElemwiseCompute<xpu, LOP>(ctx, rhs, lhs, out, req[0]); + } + } else { + PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); + } +} + +template<typename xpu, typename OP> +void MixedAllRealBinaryBroadcastCompute(const std::string& op_name, + const OpContext& ctx, + const TBlob& lhs, + const TBlob& rhs, + const TBlob& out, + const OpReqType req, + const int ndim, + const mxnet::TShape& new_oshape, + const mxnet::TShape& new_lshape, + const mxnet::TShape& new_rshape) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(lhs.type_flag_, out.type_flag_); + + Stream<xpu> *s = ctx.get_stream<xpu>(); + + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape<NDim> oshape = new_oshape.get<NDim>(); + mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>()); + mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>()); + switch (lhs.type_flag_) { + case mshadow::kFloat32: + { + if (rhs.type_flag_ == mshadow::kFloat16) { + mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>:: + template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape, + rhs.dptr<mshadow::half::half_t>(), lhs.dptr<float>(), out.dptr<float>()); + } else { + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); + } + break; + } + case mshadow::kFloat64: + { + if (rhs.type_flag_ == mshadow::kFloat16) { + mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>:: + template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape, + rhs.dptr<mshadow::half::half_t>(), lhs.dptr<double>(), out.dptr<double>()); + } else if (rhs.type_flag_ == mshadow::kFloat32) { + mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>:: + template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape, + rhs.dptr<float>(), lhs.dptr<double>(), out.dptr<double>()); + } else { + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); + } + break; + } + default: + { + PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_); + break; + } + } + }); +} +#endif + +#ifndef _WIN32 +template<typename xpu, typename OP, typename LOP, typename ROP> +#else +template<typename xpu, typename OP> +#endif +void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 
2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + +#ifndef _WIN32 + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs); + } else { + mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); + if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) { + if (lhs.type_flag_ == out.type_flag_) { + MixedAllRealBinaryBroadcastCompute<xpu, ROP>( + attrs.op->name, ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape); + } else { + MixedAllRealBinaryBroadcastCompute<xpu, LOP>( + attrs.op->name, ctx, rhs, lhs, out, req[0], ndim, new_oshape, new_rshape, new_lshape); + } + } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + CHECK(lhs.type_flag_ == out.type_flag_ || rhs.type_flag_ == out.type_flag_) + << "One of the input types should be the same as the output type"; + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape<NDim> oshape = new_oshape.get<NDim>(); + mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>()); + mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>()); + if (lhs.type_flag_ == out.type_flag_) { + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, LType, { + MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, { + mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, ROP>, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], rstride, lstride, oshape, + rhs.dptr<RType>(), lhs.dptr<LType>(), out.dptr<LType>()); + }); + }); + } else { + MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, RType, { + MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, { + mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, LOP>, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + lhs.dptr<LType>(), rhs.dptr<RType>(), out.dptr<RType>()); + }); + }); + } + }); + } else { + PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); + } + } +#else + mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); + if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) { + TBlob temp_tblob; + // one input is a float type, the other is an integer or boolean type + CHECK((out.type_flag_ == lhs.type_flag_) || (out.type_flag_ == rhs.type_flag_)) + << "In this case, the output type should be the same as the float input type"; + if (lhs.type_flag_ == out.type_flag_) { + MSHADOW_REAL_TYPE_SWITCH(lhs.type_flag_, LType, { + Tensor<xpu, 1, LType> temp_tensor = + ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute<xpu, OP>( + attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs); + } else { + MSHADOW_REAL_TYPE_SWITCH(rhs.type_flag_, RType, { + Tensor<xpu, 1, RType> temp_tensor = + ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s); + temp_tblob = TBlob(temp_tensor); + }); + CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob}); + BinaryBroadcastCompute<xpu, OP>( + attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs); + } + } else { + PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); + } +#endif +}
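+// NumpyBinaryBroadcastCompute below is the operator entry point: inputs that
+// already share a dtype take the regular BinaryBroadcastCompute path, and only
+// mixed-dtype inputs reach MixedBinaryBroadcastCompute above. On non-Windows
+// builds that mixed path launches dedicated mixed-type kernels (LOP and ROP
+// cover the two operand orders); MSVC builds instead cast the non-float
+// operand into temporary workspace and reuse the same-dtype kernel, which is
+// why the template parameter lists differ under _WIN32.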
+#ifndef _WIN32 +template<typename xpu, typename OP, typename LOP, typename ROP> +#else +template<typename xpu, typename OP> +#endif +void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + + if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return; + + if (lhs.type_flag_ == rhs.type_flag_) { + BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs); + return; + } + +#ifndef _WIN32 + MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs); +#else + MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs); +#endif +} + +#ifndef _WIN32 +template<typename xpu, typename OP, typename LOP, typename ROP> +#else +template<typename xpu, typename OP> +#endif +void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + + const TBlob& lhs = inputs[0]; + const TBlob& rhs = inputs[1]; + const TBlob& out = outputs[0]; + + if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return; + + if (lhs.type_flag_ == rhs.type_flag_) { + BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs); + return; + } + +#ifndef _WIN32 + MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs); +#else + MixedBinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs); +#endif +} + +template<typename xpu, typename LOP, typename ROP> +void MixedBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + + const TBlob& lhs = inputs[1]; + const TBlob& rhs = inputs[2]; + if (lhs.type_flag_ == rhs.type_flag_) { + BinaryBroadcastBackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs); + return; + } + + PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_); +} + +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
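For reference, the behaviour the kernels above (and the np_true_divide / rtrue_divide changes just below) implement can be sketched from Python roughly as follows, assuming a build that includes these backports, with numpy semantics enabled and _np being plain NumPy as in the test files:

    import mxnet as mx
    import numpy as _np
    from mxnet import np, npx
    npx.set_np()  # enable numpy-compatible semantics

    # Mixed-precision promotion: the result takes the float input type,
    # or the more precise one when both inputs are floats.
    a = np.array([1, 2, 3], dtype='int32')
    b = np.array([1.5, 2.5, 3.5], dtype='float32')
    h = np.array([1, 2, 3], dtype='float16')
    assert (a + b).dtype == np.float32   # int32 + float32 -> float32
    assert (h * b).dtype == np.float32   # float16 * float32 -> float32

    # Pure-boolean elemwise/broadcast add, multiply and true_divide.
    x = np.array(_np.random.uniform(size=(3,)) > 0.5)  # boolean ndarray
    y = np.array(_np.random.uniform(size=(3,)) > 0.5)
    assert (x * y).dtype == np.bool                    # stays boolean
    assert np.true_divide(x, y).dtype == np.float32    # integral inputs divide to float32

    # The FGradient fix below: scalar / x needs its input to compute
    # d(c/x)/dx = -c / x**2, hence ElemwiseGradUseIn instead of UseNone.
    t = np.array([1.0, 2.0, 4.0])
    t.attach_grad()
    with mx.autograd.record():
        out = 2.0 / t
    out.backward()
    print(t.grad)  # expect [-2.0, -0.5, -0.125]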
diff --git a/src/operator/numpy/np_true_divide.cc index d2135be..1e46cc9 100644 --- a/src/operator/numpy/np_true_divide.cc +++ b/src/operator/numpy/np_true_divide.cc @@ -31,7 +31,7 @@ namespace op { int TrueDivideOutType(int ltype, int rtype) { if (common::is_float(ltype) && common::is_float(rtype)) { // If both inputs are float, return the one with the higher precision - return common::more_precise_type(ltype, rtype); + return common::get_more_precise_type(ltype, rtype); } else if (common::is_float(ltype) || common::is_float(rtype)) { // If only one of the inputs is float, return that float type return (common::is_float(ltype)) ? ltype : rtype; @@ -126,7 +126,7 @@ NNVM_REGISTER_OP(_npi_rtrue_divide_scalar) }) #endif .set_attr<FCompute>("FCompute<cpu>", TrueDivideScalarCompute<cpu, mshadow_op::rtrue_divide>) -.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_rdiv_scalar"}) +.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_rdiv_scalar"}) .add_argument("data", "NDArray-or-Symbol", "source input") .add_argument("scalar", "float", "scalar input"); diff --git a/src/operator/operator_tune-inl.h index 1dbcf42..122ec04 100644 --- a/src/operator/operator_tune-inl.h +++ b/src/operator/operator_tune-inl.h @@ -124,7 +124,8 @@ class OperatorTune : public OperatorTuneByType<DType> { if (!initialized_) { initialized_ = true; // Generate some random data for calling the operator kernels - data_set_.reserve(0x100); + data_set_ = + std::unique_ptr<DType[]>(reinterpret_cast<DType*>(new char[0x100 * sizeof(DType)])); std::random_device rd; std::mt19937 gen(rd()); if (!std::is_integral<DType>::value) { @@ -136,7 +137,7 @@ class OperatorTune : public OperatorTuneByType<DType> { --n; continue; } - data_set_.emplace_back(val); + data_set_[n] = val; } } else { std::uniform_int_distribution<> dis(-128, 127); @@ -147,7 +148,7 @@ class OperatorTune : public OperatorTuneByType<DType> { --n; continue; } - data_set_.emplace_back(val); + data_set_[n] = val; } } // Use this environment variable to generate new tuning statistics @@ -517,7 +518,7 @@ class OperatorTune : public OperatorTuneByType<DType> { /*! \brief Number of passes to obtain an average */ static constexpr duration_t OUTSIDE_COUNT = (1 << OUTSIDE_COUNT_SHIFT); /*! \brief Random data for timing operator calls */ - static std::vector<DType> data_set_; + static std::unique_ptr<DType[]> data_set_; /*! \brief Operators tuned */ static std::unordered_set<std::string> operator_names_; /*!
\brief Arbitary object to modify in OMP loop */ diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index d0642ee..633f630 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -39,7 +39,7 @@ double OperatorTuneBase::tuning_weight_scale_ = 0.0; */ #define IMPLEMENT_OPERATOR_TUNE_STATICS_FOR_TYPE(__typ$) \ template<> bool OperatorTune<__typ$>::initialized_ = false; \ - template<> std::vector<__typ$> OperatorTune<__typ$>::data_set_ = {}; \ + template<> std::unique_ptr<__typ$[]> OperatorTune<__typ$>::data_set_ = nullptr; \ template<> volatile tune::TuningMode OperatorTuneByType<__typ$>::tuning_mode_ = tune::kAuto; \ template<> volatile int OperatorTune<__typ$>::volatile_int_ = 9; /* arbitrary number */ \ template<> std::unordered_set<std::string> OperatorTune<__typ$>::operator_names_({}); \ @@ -314,10 +314,10 @@ IMPLEMENT_UNARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::np_logical_not); IMPLEMENT_UNARY_WORKLOAD_BWD(mxnet::op::mshadow_op::nt); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::clip); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::clip); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::plus); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::mul); // NOLINT() -IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::div); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::plus); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::minus); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::mul); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::div); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::true_divide); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::minus_sign); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rminus); // NOLINT() diff --git a/src/operator/tensor/broadcast_reduce-inl.cuh b/src/operator/tensor/broadcast_reduce-inl.cuh index fc863a9..360420a 100644 --- a/src/operator/tensor/broadcast_reduce-inl.cuh +++ b/src/operator/tensor/broadcast_reduce-inl.cuh @@ -619,8 +619,6 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req, ReduceImplConfig<ndim> config = ConfigureReduceImpl<ndim, DType>(small.shape_, big.shape_, NULL, NULL); if (safe_acc) { - // TODO(haojin2): Use real-only type swtich for windows temporarily due to CI issues. 
-#ifndef _WIN32 MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, { typedef typename std::conditional<safe_acc, AType, DataType>::type AccType; MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { @@ -630,17 +628,6 @@ void Reduce(Stream<gpu> *s, const TBlob& small, const OpReqType req, stream, small, req, big, workspace, config); }); }); -#else - MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, { - typedef typename std::conditional<safe_acc, AType, DataType>::type AccType; - MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { - typedef typename std::conditional<safe_acc, OType, DataType>::type OutType; - config = ConfigureReduceImpl<ndim, AccType>(small.shape_, big.shape_, NULL, NULL); - ReduceImpl<Reducer, ndim, AccType, DataType, OutType, OP>( - stream, small, req, big, workspace, config); - }); - }); -#endif } else { ReduceImpl<Reducer, ndim, DType, DType, DType, OP>(stream, small, req, big, workspace, config); } diff --git a/src/operator/tensor/broadcast_reduce-inl.h b/src/operator/tensor/broadcast_reduce-inl.h index 415059a..e203da2 100644 --- a/src/operator/tensor/broadcast_reduce-inl.h +++ b/src/operator/tensor/broadcast_reduce-inl.h @@ -241,28 +241,15 @@ void Reduce(Stream<cpu>* s, const TBlob& small, const OpReqType req, N, M, req == kAddTo, big.dptr<DType>(), small.dptr<DType>(), big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride); } else { - // TODO(haojin2): Use real-only type swtich for windows temporarily due to CI issues. -#ifndef _WIN32 MXNET_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, { typedef typename std::conditional<safe_acc, AType, DataType>::type AccType; - MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(small.type_flag_, OType, { typedef typename std::conditional<safe_acc, OType, DataType>::type OutType; seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>( N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(), big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride); }); }); -#else - MXNET_REAL_ACC_TYPE_SWITCH(mshadow::DataType<DType>::kFlag, DataType, AType, { - typedef typename std::conditional<safe_acc, AType, DataType>::type AccType; - MSHADOW_TYPE_SWITCH(small.type_flag_, OType, { - typedef typename std::conditional<safe_acc, OType, DataType>::type OutType; - seq_reduce_compute<Reducer, ndim, AccType, DataType, OutType, OP>( - N, M, req == kAddTo, big.dptr<DataType>(), small.dptr<OutType>(), - big.shape_.get<ndim>(), small.shape_.get<ndim>(), rshape, rstride); - }); - }); -#endif } } diff --git a/src/operator/tensor/broadcast_reduce_op.h b/src/operator/tensor/broadcast_reduce_op.h index 414b606..27e2249 100644 --- a/src/operator/tensor/broadcast_reduce_op.h +++ b/src/operator/tensor/broadcast_reduce_op.h @@ -617,7 +617,7 @@ void ReduceAxesComputeImpl(const OpContext& ctx, BroadcastReduceShapeCompact(inputs[0].shape_, small, &src_shape, &dst_shape); Stream<xpu> *s = ctx.get_stream<xpu>(); MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { const TBlob in_data = inputs[0].reshape(src_shape); const TBlob out_data = outputs[0].reshape(dst_shape); BROADCAST_NDIM_SWITCH(dst_shape.ndim(), NDim, { @@ -1045,8 +1045,8 @@ inline void BroadcastComputeImpl(const nnvm::NodeAttrs& attrs, mxnet::TShape src_shape, dst_shape; BroadcastReduceShapeCompact(outputs[0].shape_, small, &dst_shape, &src_shape); Stream<xpu> 
*s = ctx.get_stream<xpu>(); - MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, IType, { - MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, OType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, IType, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, OType, { mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> in_shape; mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> out_shape; for (int i = 0; i < MXNET_SPECIAL_MAX_NDIM; ++i) { diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h index ad06df8..b48ed38 100644 --- a/src/operator/tensor/elemwise_binary_broadcast_op.h +++ b/src/operator/tensor/elemwise_binary_broadcast_op.h @@ -347,6 +347,9 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, } else { if (req[0] != kNullOp) { mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); + if (outputs[0].type_flag_ == mshadow::kBool) { + LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; + } MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { BROADCAST_NDIM_SWITCH(ndim, NDim, { mshadow::Shape<NDim> oshape = new_oshape.get<NDim>(); @@ -362,6 +365,35 @@ void BinaryBroadcastCompute(const nnvm::NodeAttrs& attrs, } template<typename xpu, typename OP> +void BinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + if (outputs[0].shape_.Size() == 0U) return; + mxnet::TShape new_lshape, new_rshape, new_oshape; + int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_, + &new_lshape, &new_rshape, &new_oshape); + if (!ndim) { + ElemwiseBinaryOp::ComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs); + } else { + if (req[0] != kNullOp) { + mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { + BROADCAST_NDIM_SWITCH(ndim, NDim, { + mshadow::Shape<NDim> oshape = new_oshape.get<NDim>(); + mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>()); + mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>()); + mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>:: + template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape, + inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), outputs[0].dptr<DType>()); + }); + }); + } + } +} + +template<typename xpu, typename OP> void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs, const OpContext& ctx, const std::vector<TBlob>& inputs, diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 6f444ae..c046a28 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -99,11 +99,13 @@ class ElemwiseBinaryOp : public OpBase { return a1.var() == a2.var(); } + public: /*! \brief Minimum of three */ static MSHADOW_XINLINE size_t minthree(const size_t a, const size_t b, const size_t c) { return a < b ? (a < c ? a : c) : (b < c ? 
b : c); } + private: template<typename xpu, typename LOP, typename ROP, typename DType> static void BackwardUseNone_(const nnvm::NodeAttrs &attrs, const OpContext &ctx, @@ -483,6 +485,9 @@ class ElemwiseBinaryOp : public OpBase { Stream<xpu> *s = ctx.get_stream<xpu>(); CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); + if (outputs[0].type_flag_ == mshadow::kBool) { + LOG(FATAL) << "Operator " << attrs.op->name << " does not support boolean type"; + } MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, { const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) @@ -498,6 +503,31 @@ class ElemwiseBinaryOp : public OpBase { } template<typename xpu, typename OP> + static void ComputeWithBool(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector<TBlob> &inputs, + const std::vector<OpReqType> &req, + const std::vector<TBlob> &outputs) { + using namespace mxnet_op; + if (req[0] != kNullOp) { + Stream<xpu> *s = ctx.get_stream<xpu>(); + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + MXNET_ASSIGN_REQ_SWITCH(req[0], Req, { + MSHADOW_TYPE_SWITCH_WITH_BOOL(outputs[0].type_flag_, DType, { + const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size()) + + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes; + if (size != 0) { + Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size, + outputs[0].dptr<DType>(), + inputs[0].dptr<DType>(), inputs[1].dptr<DType>()); + } + }); + }); + } + } + + template<typename xpu, typename OP> static void ComputeLogic(const nnvm::NodeAttrs &attrs, const OpContext &ctx, const std::vector<TBlob> &inputs, diff --git a/tests/nightly/JenkinsfileForBinaries b/tests/nightly/JenkinsfileForBinaries index 48db445..2b55c05 100755 --- a/tests/nightly/JenkinsfileForBinaries +++ b/tests/nightly/JenkinsfileForBinaries @@ -18,9 +18,9 @@ // //This is a Jenkinsfile for nightly tests. 
The format and some functions have been picked up from the top-level Jenkinsfile -mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' -mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' -mx_lib_cpp_example_mkl = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, build/cpp-package/example/imagenet_inference, lib/libmkldnn.so.1' +mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so' +mx_lib_cpp_example_mkl = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, build/cpp-package/example/imagenet_inference, lib/libmkldnn.so.1' node('utility') { // Loading the utilities requires a node context unfortunately diff --git a/tests/nightly/model_backwards_compatibility_check/JenkinsfileForMBCC b/tests/nightly/model_backwards_compatibility_check/JenkinsfileForMBCC index 725261d..e419aa7 100644 --- a/tests/nightly/model_backwards_compatibility_check/JenkinsfileForMBCC +++ b/tests/nightly/model_backwards_compatibility_check/JenkinsfileForMBCC @@ -18,7 +18,7 @@ // //This is a Jenkinsfile for the model backwards compatibility checker. The format and some functions have been picked up from the top-level Jenkinsfile. 
-mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so,lib/libtvmop.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' +mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a' node('restricted-utility') { // Loading the utilities requires a node context unfortunately diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py index 6adf935..5606eb1 100644 --- a/tests/python/gpu/test_fusion.py +++ b/tests/python/gpu/test_fusion.py @@ -28,6 +28,7 @@ from common import with_seed def check_fused_symbol(sym, **kwargs): inputs = sym.list_inputs() shapes = {inp : kwargs[inp].shape for inp in inputs} + ctx = kwargs.get('ctx', mx.gpu(0)) # Double identity so that there is always something to fuse test_sym = mx.sym.Group([mx.sym.identity(mx.sym.identity(s)) for s in sym]) rtol = {'float16' : 1e-2, @@ -43,9 +44,9 @@ def check_fused_symbol(sym, **kwargs): for grad_req in ['write', 'add']: type_dict = {inp : dtype for inp in inputs} os.environ["MXNET_USE_FUSION"] = "0" - orig_exec = test_sym.simple_bind(ctx=mx.gpu(0), grad_req=grad_req, type_dict=type_dict, **shapes) + orig_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) os.environ["MXNET_USE_FUSION"] = "1" - fused_exec = test_sym.simple_bind(ctx=mx.gpu(0), grad_req=grad_req, type_dict=type_dict, **shapes) + fused_exec = test_sym.simple_bind(ctx=ctx, grad_req=grad_req, type_dict=type_dict, **shapes) fwd_orig = orig_exec.forward(is_train=True, **data) out_grads = [mx.nd.ones_like(arr) for arr in fwd_orig] orig_exec.backward(out_grads=out_grads) @@ -218,6 +219,26 @@ def test_fusion(): check_binary_ops() check_other_ops() +@with_seed() +def test_fusion_compiler_cache(): + # Stresses the internal cache of CUfunctions by creating the same kernel multiple times and + # on multiple GPUs if available. 
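+ # (The fused-op backend caches the compiled PTX per kernel source and the
+ # JIT-compiled CUfunction per device, so the second invocation below should
+ # be served from the cache rather than recompiled; when a second GPU is
+ # present, it exercises building a per-device function from the cached PTX.)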
+ a = mx.sym.Variable('a') + b = mx.sym.Variable('b') + shape = rand_shape_2d() + arr1 = mx.random.uniform(shape=shape) + arr2 = mx.random.uniform(shape=shape) + + # Invoke the same model twice, second time will exercise compile cache + check_fused_symbol(a+b, ctx=mx.gpu(0), a=arr1, b=arr2) + check_fused_symbol(a+b, ctx=mx.gpu(0), a=arr1, b=arr2) + + # On multi-GPU systems, invoke the same model on other GPUs + num_gpus = mx.context.num_gpus() + if num_gpus > 1: + check_fused_symbol(a+b, ctx=mx.gpu(1), a=arr1, b=arr2) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py index 12e89a2..15914db 100644 --- a/tests/python/unittest/test_numpy_gluon.py +++ b/tests/python/unittest/test_numpy_gluon.py @@ -179,6 +179,23 @@ def test_np_get_constant(): assert_almost_equal(out.asnumpy(), (x.asnumpy() + const_arr), atol=1e-5, rtol=1e-4, use_broadcast=False) +@use_np +def test_parameters_zero_grad(): + for hybridize in [False, True]: + net = gluon.nn.HybridSequential() + for _ in range(5): + net.add(gluon.nn.Dense(10)) + if hybridize: + net.hybridize() + net.initialize() + out = net(mx.np.ones((32, 8))) + for v in net.collect_params().values(): + v.grad()[()] = 1 + net.collect_params().zero_grad() + for v in net.collect_params().values(): + assert_almost_equal(v.grad().asnumpy(), mx.np.zeros_like(v.grad()).asnumpy()) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index 239f300..8e46f03 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -27,7 +27,7 @@ from mxnet import np, npx, autograd from mxnet.gluon import HybridBlock from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, use_np from common import with_seed, TemporaryDirectory -from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, assert_exception, is_op_runnable +from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, assert_exception, is_op_runnable, collapse_sum_like from mxnet.ndarray.ndarray import py_slice from mxnet.base import integer_types import scipy.stats as ss @@ -281,6 +281,62 @@ def test_np_ndarray_binary_element_wise_ops(): '<=': _np.less_equal }) + def _get_grad_func(op, scalar=None, reverse=False): + if op == '+': + if scalar is None: + return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, x1.shape), + collapse_sum_like(ograd, x2.shape)) + elif not reverse: + return lambda ograd, x1, x2, out: ograd + else: + return lambda ograd, x1, x2, out: ograd + elif op == '-': + if scalar is None: + return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, x1.shape), + -collapse_sum_like(ograd, x2.shape)) + elif not reverse: + return lambda ograd, x1, x2, out: ograd + else: + return lambda ograd, x1, x2, out: -ograd + elif op == '*': + if scalar is None: + return lambda ograd, x1, x2, out: (collapse_sum_like(ograd * x2, x1.shape), + collapse_sum_like(ograd * x1, x2.shape)) + elif not reverse: + return lambda ograd, x1, x2, out: ograd * x2 + else: + return lambda ograd, x1, x2, out: ograd * x1 + elif op == '/': + if scalar is None: + return lambda ograd, x1, x2, out: (collapse_sum_like(ograd / x2, x1.shape), + collapse_sum_like(-x1 * ograd / (x2 * x2), x2.shape)) + elif not reverse: + return lambda ograd, x1, x2, out: ograd / x2 + else: + return lambda ograd, x1, x2, out: -x1 * ograd / 
(x2 * x2) + elif op == 'mod': + if scalar is None: + return lambda ograd, x1, x2, out: (collapse_sum_like(ograd, x1.shape), + collapse_sum_like(-ograd * _np.floor(x1 / x2), x2.shape)) + elif not reverse: + return lambda ograd, x1, x2, out: ograd + else: + return lambda ograd, x1, x2, out: -ograd * _np.floor(x1 / x2) + elif op == 'pow': + if scalar is None: + return lambda ograd, x1, x2, out: (collapse_sum_like(ograd * x2 * _np.power(x1, x2 - 1), x1.shape), + collapse_sum_like(ograd * out * _np.log(x1), x2.shape)) + elif not reverse: + return lambda ograd, x1, x2, out: ograd * x2 * _np.power(x1, x2 - 1) + else: + return lambda ograd, x1, x2, out: ograd * out * _np.log(x1) + elif op in ('==', '!=', '<', '<=', '>', '>='): + if scalar is None: + return lambda ograd, x1, x2, out: (_np.zeros_like(x1), _np.zeros_like(x2)) + else: + return lambda ograd, x1, x2, out: _np.zeros_like(ograd) + return None + def get_np_ret(x1, x2, op): return np_op_map[op](x1, x2) @@ -364,13 +420,15 @@ def test_np_ndarray_binary_element_wise_ops(): mx_input1 = abs(_np.random.uniform()) + 1 np_input1 = mx_input1 else: - mx_input1 = rand_ndarray(shape1, dtype=dtype).abs() + 1 + mx_input1 = (rand_ndarray(shape1, dtype=dtype).abs() + 1).as_np_ndarray() + mx_input1.attach_grad() np_input1 = mx_input1.asnumpy() if shape2 is None: mx_input2 = abs(_np.random.uniform()) + 1 np_input2 = mx_input2 else: - mx_input2 = rand_ndarray(shape2, dtype=dtype).abs() + 1 + mx_input2 = (rand_ndarray(shape2, dtype=dtype).abs() + 1).as_np_ndarray() + mx_input2.attach_grad() np_input2 = mx_input2.asnumpy() scalar = None @@ -382,7 +440,9 @@ def test_np_ndarray_binary_element_wise_ops(): scalar = mx_input1 reverse = True + grad_func = _get_grad_func(op, scalar, reverse) np_out = get_np_ret(np_input1, np_input2, op) + ograd = _np.ones_like(np_out) for hybridize in [True, False]: if scalar is None: get_mx_ret_np = TestBinaryElementWiseOp(op) @@ -390,26 +450,49 @@ def test_np_ndarray_binary_element_wise_ops(): if hybridize: get_mx_ret_np.hybridize() get_mx_ret_classic.hybridize() - mx_out = get_mx_ret_np(mx_input1.as_np_ndarray(), mx_input2.as_np_ndarray()) + if grad_func is None: + mx_out = get_mx_ret_np(mx_input1, mx_input2) + else: + with mx.autograd.record(): + mx_out = get_mx_ret_np(mx_input1, mx_input2) + mx_out.backward() assert type(mx_out) == np.ndarray - assert np_out.shape == mx_out.shape if op in logic_ops: assert np_out.dtype == mx_out.dtype - assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5, use_broadcast=False) + + if grad_func is not None: + x1_grad_expected, x2_grad_expected = grad_func(ograd, np_input1, np_input2, np_out) + assert_almost_equal(mx_input1.grad.asnumpy(), x1_grad_expected, atol=1e-5, rtol=1e-3, + use_broadcast=False) + assert_almost_equal(mx_input2.grad.asnumpy(), x2_grad_expected, atol=1e-5, rtol=1e-3, + use_broadcast=False) else: get_mx_ret = TestBinaryElementWiseOp(op, scalar=scalar, reverse=reverse) if hybridize: get_mx_ret.hybridize() if reverse: - mx_out = get_mx_ret(mx_input2.as_np_ndarray()) - assert type(mx_out) == np.ndarray + mx_input = mx_input2 else: - mx_out = get_mx_ret(mx_input1.as_np_ndarray()) - assert type(mx_out) == np.ndarray - assert np_out.shape == mx_out.shape + mx_input = mx_input1 + + if grad_func is None: + mx_out = get_mx_ret(mx_input) + else: + with mx.autograd.record(): + mx_out = get_mx_ret(mx_input) + mx_out.backward() + assert type(mx_out) == np.ndarray + if op in logic_ops: assert np_out.dtype == 
mx_out.dtype - assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5) + assert_almost_equal(mx_out.asnumpy(), np_out, atol=1e-6, rtol=1e-5, use_broadcast=False) + + # check grad + if grad_func is not None: + x_grad_expected = grad_func(ograd, np_input1, np_input2, np_out) + assert_almost_equal(mx_input.grad.asnumpy(), x_grad_expected, atol=1e-5, rtol=1e-3, + use_broadcast=False) dtypes = [_np.float32, _np.float64, None] ops = np_op_map.keys() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index c1a6ed5..9a36e06 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -614,52 +614,83 @@ def test_np_mean(): def is_int(dtype): return 'int' in dtype + is_windows = sys.platform.startswith('win') in_data_dim = random.choice([2, 3, 4]) shape = rand_shape_nd(in_data_dim, dim=3) acc_type = {'float16': 'float32', 'float32': 'float64', 'float64': 'float64', - 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'} + 'bool': 'int64', 'int8': 'int32', 'int32': 'int64', 'int64': 'int64'} + ft_types = ['float16', 'float32', 'float64'] + it_types = ['bool', 'int8', 'int32', 'int64'] for hybridize in [False, True]: for keepdims in [True, False]: for axis in ([i for i in range(in_data_dim)] + [(), None]): - for itype in ['float16', 'float32', 'float64']: - for dtype in ['float16', 'float32', 'float64']: - if is_int(dtype) and not is_int(itype): - continue - # test gluon - test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims) - if hybridize: - test_mean.hybridize() - if is_int(itype): - x = _np.random.randint(-128, 128, shape, dtype=itype) - x = mx.nd.array(x, dtype=itype) - else: - x = mx.nd.random.uniform(-1.0, 1.0, shape=shape, dtype=itype) - x = x.as_np_ndarray() - x.attach_grad() + for itype, dtype in itertools.product(ft_types, [None] + ft_types + it_types): + if dtype == 'bool': + continue + # test gluon + test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims) + if hybridize: + test_mean.hybridize() + x = np.random.uniform(-1.0, 1.0, size=shape).astype(itype) + x = x.as_np_ndarray() + x.attach_grad() - expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims) - expected_ret = expected_ret.astype(dtype) - with mx.autograd.record(): - y = test_mean(x) - assert y.shape == expected_ret.shape - assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3, - atol=1e-5 if dtype == 'float16' else 1e-5) + expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims) + expected_ret = expected_ret.astype(dtype) + with mx.autograd.record(): + y = test_mean(x) + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3, + atol=1e-5 if dtype == 'float16' else 1e-5) - y.backward() - N = x.size / y.size - assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N) + y.backward() + N = x.size / y.size + assert same(x.grad.asnumpy(), _np.ones(shape=x.shape, dtype=x.dtype) / N) - # test numeric - if itype == 'float32' and dtype == 'float32': - x_sym = mx.sym.Variable("x").as_np_ndarray() - mx_sym = mx.sym.np.mean(x_sym, axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray() - check_numeric_gradient(mx_sym, [x.as_nd_ndarray()], - numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) + # test numeric + if itype == 'float32' and dtype == 'float32': + x_sym = mx.sym.Variable("x").as_np_ndarray() + mx_sym = mx.sym.np.mean(x_sym, 
axis=axis, dtype=dtype, keepdims=keepdims).as_nd_ndarray() + check_numeric_gradient(mx_sym, [x.as_nd_ndarray()], + numeric_eps=1e-3, rtol=1e-3, atol=1e-4, dtype=_np.float32) - # test imperative - mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims) - np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype) - assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + # test imperative + mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims) + np_out = _np.mean(x.asnumpy(), axis=axis, dtype=acc_type[itype], keepdims=keepdims).astype(dtype) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + for itype, dtype in itertools.product(it_types, [None] + ft_types + it_types): + if dtype == 'bool': + continue + # test gluon + test_mean = TestMean(axis=axis, dtype=dtype, keepdims=keepdims) + if hybridize: + test_mean.hybridize() + + if itype == 'bool': + x = np.array(_np.random.uniform(size=shape) > 0.5) + else: + x = np.random.uniform(-128, 127, size=shape).astype(itype) + + expected_ret = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims) + + if itype == 'bool': + if is_op_runnable() and (not is_windows) and dtype not in ['float16', 'int8']: # special handling of boolean ndarray + y = test_mean(x) + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3, + atol=1e-5 if dtype == 'float16' else 1e-5) + continue + + y = test_mean(x) + assert y.shape == expected_ret.shape + assert_almost_equal(y.asnumpy(), expected_ret, rtol=1e-3 if dtype == 'float16' else 1e-3, + atol=1e-5 if dtype == 'float16' else 1e-5) + + # test imperative + mx_out = np.mean(x, axis=axis, dtype=dtype, keepdims=keepdims) + np_out = _np.mean(x.asnumpy(), axis=axis, dtype=dtype, keepdims=keepdims).astype(dtype) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) @with_seed() @@ -1572,8 +1603,8 @@ def test_np_binary_funcs(): rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False) if rgrads is None: assert_almost_equal(mx_test_x2.grad.asnumpy(), - collapse_sum_like(rgrad(y.asnumpy(), np_test_x2, np_test_x1), mx_test_x2.shape), - rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False) + collapse_sum_like(rgrad(y.asnumpy(), np_test_x2, np_test_x1), mx_test_x2.shape), + rtol=1e-1, atol=1e-2, equal_nan=True, use_broadcast=False) else: assert_almost_equal(mx_test_x2.grad.asnumpy(), collapse_sum_like(rgrad(y.asnumpy(), np_test_x1, np_test_x2), mx_test_x2.shape), @@ -1594,7 +1625,6 @@ def test_np_binary_funcs(): assertRaises(NotImplementedError, getattr(np, func), mx_test_x1, mx_test_x2, order='C') assertRaises(NotImplementedError, getattr(np, func), mx_test_x1, mx_test_x2, order='mxnet') - funcs = { 'add': (-1.0, 1.0, [lambda y, x1, x2: _np.ones(y.shape)], None), 'subtract': @@ -1603,7 +1633,7 @@ def test_np_binary_funcs(): 'multiply': (-1.0, 1.0, [lambda y, x1, x2: _np.broadcast_to(x2, y.shape)], [lambda y, x1, x2: _np.broadcast_to(x1, y.shape)]), 'divide': (0.1, 1.0, [lambda y, x1, x2: _np.ones(y.shape) / x2], - [lambda y, x1, x2: -x1 / (x2 * x2)]), + [lambda y, x1, x2: -x1 / (x2 * x2)]), 'mod': (1.0, 10.0, [lambda y, x1, x2: _np.ones(y.shape), lambda y, x1, x2: _np.zeros(y.shape)], @@ -1652,6 +1682,125 @@ def test_np_binary_funcs(): @with_seed() @use_np +def test_np_mixed_precision_binary_funcs(): + def check_mixed_precision_binary_func(func, low, high, lshape, rshape, ltype, rtype): + class TestMixedBinary(HybridBlock): + def 
__init__(self, func): + super(TestMixedBinary, self).__init__() + self._func = func + + def hybrid_forward(self, F, a, b, *args, **kwargs): + return getattr(F.np, self._func)(a, b) + + np_func = getattr(_np, func) + mx_func = TestMixedBinary(func) + np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype) + np_test_x2 = _np.random.uniform(low, high, rshape).astype(rtype) + mx_test_x1 = mx.numpy.array(np_test_x1, dtype=ltype) + mx_test_x2 = mx.numpy.array(np_test_x2, dtype=rtype) + rtol = 1e-2 if ltype is np.float16 or rtype is np.float16 else 1e-3 + atol = 1e-4 if ltype is np.float16 or rtype is np.float16 else 1e-5 + for hybridize in [True, False]: + if hybridize: + mx_func.hybridize() + np_out = np_func(np_test_x1, np_test_x2) + with mx.autograd.record(): + y = mx_func(mx_test_x1, mx_test_x2) + assert y.shape == np_out.shape + assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=rtol, atol=atol, + use_broadcast=False, equal_nan=True) + + np_out = getattr(_np, func)(np_test_x1, np_test_x2) + mx_out = getattr(mx.np, func)(mx_test_x1, mx_test_x2) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=rtol, atol=atol, + use_broadcast=False, equal_nan=True) + + funcs = { + 'add': (-1.0, 1.0), + 'subtract': (-1.0, 1.0), + 'multiply': (-1.0, 1.0), + } + + shape_pairs = [((3, 2), (3, 2)), + ((3, 2), (3, 1)), + ((3, 1), (3, 0)), + ((0, 2), (1, 2)), + ((2, 3, 4), (3, 1)), + ((2, 3), ()), + ((), (2, 3))] + + itypes = [np.bool, np.int8, np.int32, np.int64] + ftypes = [np.float16, np.float32, np.float64] + for func, func_data in funcs.items(): + low, high = func_data + for lshape, rshape in shape_pairs: + for type1, type2 in itertools.product(itypes, ftypes): + check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2) + check_mixed_precision_binary_func(func, low, high, lshape, rshape, type2, type1) + + for type1, type2 in itertools.product(ftypes, ftypes): + if type1 == type2: + continue + check_mixed_precision_binary_func(func, low, high, lshape, rshape, type1, type2) + + +@with_seed() +@use_np +def test_np_boolean_binary_funcs(): + def check_boolean_binary_func(func, mx_x1, mx_x2): + class TestBooleanBinary(HybridBlock): + def __init__(self, func): + super(TestBooleanBinary, self).__init__() + self._func = func + + def hybrid_forward(self, F, a, b, *args, **kwargs): + return getattr(F.np, self._func)(a, b) + + np_x1 = mx_x1.asnumpy() + np_x2 = mx_x2.asnumpy() + np_func = getattr(_np, func) + mx_func = TestBooleanBinary(func) + for hybridize in [True, False]: + if hybridize: + mx_func.hybridize() + np_out = np_func(np_x1, np_x2) + with mx.autograd.record(): + y = mx_func(mx_x1, mx_x2) + assert y.shape == np_out.shape + assert_almost_equal(y.asnumpy(), np_out.astype(y.dtype), rtol=1e-3, atol=1e-20, + use_broadcast=False, equal_nan=True) + + np_out = getattr(_np, func)(np_x1, np_x2) + mx_out = getattr(mx.np, func)(mx_x1, mx_x2) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out.astype(mx_out.dtype), rtol=1e-3, atol=1e-20, + use_broadcast=False, equal_nan=True) + + + funcs = [ + 'add', + 'multiply', + 'true_divide', + ] + + shape_pairs = [((3, 2), (3, 2)), + ((3, 2), (3, 1)), + ((3, 1), (3, 0)), + ((0, 2), (1, 2)), + ((2, 3, 4), (3, 1)), + ((2, 3), ()), + ((), (2, 3))] + + for lshape, rshape in shape_pairs: + for func in funcs: + x1 = np.array(_np.random.uniform(size=lshape) > 0.5) + x2 = np.array(_np.random.uniform(size=rshape) > 0.5) + check_boolean_binary_func(func, x1, 
x2) + + +@with_seed() +@use_np def test_npx_relu(): def np_relu(x): return _np.maximum(x, 0.0)