This is an automated email from the ASF dual-hosted git repository. reminisce pushed a commit to branch numpy in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/numpy by this push: new 2e10193 Enable np op compat check with name prefix (#14897) 2e10193 is described below commit 2e101935a3cdc7738dffa1bde1ef5b8fa7e31fc7 Author: reminisce <wujun....@gmail.com> AuthorDate: Mon May 6 16:56:36 2019 -0700 Enable np op compat check with name prefix (#14897) --- src/c_api/c_api_common.h | 17 ++++++++++++++++- src/operator/numpy/np_broadcast_reduce_op_value.cc | 3 +-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/c_api/c_api_common.h b/src/c_api/c_api_common.h index 118341d..ab1f5f7 100644 --- a/src/c_api/c_api_common.h +++ b/src/c_api/c_api_common.h @@ -163,10 +163,25 @@ inline void CopyAttr(const nnvm::IndexedGraph& idx, extern const std::vector<std::string> kHiddenKeys; } // namespace mxnet +/*! + * An operator is considered numpy compatible if it satisfies either one + * of the following conditions. + * 1. The op has the attribute mxnet::TIsNumpyCompatible registered as True. + * 2. The op's name starts with the prefix _numpy_. + * The first condition is usually for the ops registered as internal ops, such + * as _np_add, _true_divide, etc. They are wrapped by some user-facing op + * APIs in the Python end. + * The second condition is for the ops registered in the backend while exposed + * directly to users as is, such as _numpy_sum etc. 
+ */ inline bool IsNumpyCompatOp(const nnvm::Op* op) { static const auto& is_np_compat = nnvm::Op::GetAttr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible"); - return is_np_compat.get(op, false); + if (is_np_compat.get(op, false)) { + return true; + } + static const std::string prefix = "_numpy_"; + return op->name.find(prefix.c_str(), 0, prefix.size()) != std::string::npos; } #endif // MXNET_C_API_C_API_COMMON_H_ diff --git a/src/operator/numpy/np_broadcast_reduce_op_value.cc b/src/operator/numpy/np_broadcast_reduce_op_value.cc index 13b575a..6c81bf6 100644 --- a/src/operator/numpy/np_broadcast_reduce_op_value.cc +++ b/src/operator/numpy/np_broadcast_reduce_op_value.cc @@ -65,8 +65,7 @@ NNVM_REGISTER_OP(_numpy_sum) [](const NodeAttrs& attrs) { return std::vector<ResourceRequest>{ResourceRequest::kTempSpace}; }) -.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"}) -.set_attr<mxnet::TIsNumpyCompatible>("TIsNumpyCompatible", true); +.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_numpy_sum"}); NNVM_REGISTER_OP(_backward_numpy_sum) .set_num_outputs(1)