This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c59a325  fix mixed type binary logic operators (#18427)
c59a325 is described below

commit c59a3255346ebe9bc0729c5a702fc99624ed2374
Author: Yijun Chen <chenyijun0...@gmail.com>
AuthorDate: Tue Jun 2 15:31:16 2020 +0800

    fix mixed type binary logic operators (#18427)
---
 src/operator/mshadow_op.h                          |  6 ++--
 src/operator/mxnet_op.h                            |  7 +++++
 .../numpy/np_elemwise_broadcast_logic_op.cc        |  2 --
 src/operator/tensor/elemwise_binary_broadcast_op.h | 34 ++++++++++++++++++----
 src/operator/tensor/elemwise_binary_op.h           |  4 ++-
 tests/python/unittest/test_numpy_op.py             |  9 ++++++
 6 files changed, 52 insertions(+), 10 deletions(-)

diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 4cbb17d..9069af9 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -114,8 +114,10 @@ using std::is_integral;
 
 #define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
   struct name : public mxnet_op::tunable { \
-    template<typename DType> \
-    MSHADOW_XINLINE static bool Map(DType a, DType b) { \
+    template<typename DType, typename EType> \
+    MSHADOW_XINLINE static bool Map(DType lhs, EType rhs) { \
+      double a = static_cast<double>(lhs); \
+      double b = static_cast<double>(rhs); \
       return (expr); \
     } \
   }
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 3f1c804..bc8c0af 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -860,6 +860,13 @@ struct op_with_req {
     KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
   }
 
+  /*! \brief input is two tensors with different type and with a boolean output tensor */
+  template<typename LType, typename RType,
+           typename std::enable_if<!std::is_same<LType, RType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, bool *out, const LType *lhs, const RType *rhs) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
+  }
+
 #ifndef _WIN32
   /*! \brief inputs are two tensors with a half_t output tensor */
   template<typename DType,
diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
index 8c8320d..b191553 100644
--- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -64,8 +64,6 @@ bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
   if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false;
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1));
-  TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0));
   TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
   return true;
 }
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index ffd0f12..6f6711e 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -209,6 +209,25 @@ struct binary_broadcast_kernel {
   }
 
   /*! \brief Map function for binary_broadcast_kernel */
+  template<typename LType, typename RType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
+                                  const Shape <ndim> &lstride, const Shape <ndim> &rstride,
+                                  const Shape <ndim> &oshape, LType *lhs, RType *rhs,
+                                  OType *out) {
+    Shape <ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (index_t i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      // When tuning, don't actually run the op, since it's not going to be tuned against
+      // the actual op we'll eventually be using
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
+    }
+  }
+
+  /*! \brief Map function for binary_broadcast_kernel */
   template<typename IType, typename DType>
   MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
                                   const Shape <ndim> &lstride, const Shape <ndim> &rstride,
@@ -430,23 +449,28 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs,
                                  const std::vector<TBlob>& outputs) {
   if (outputs[0].shape_.Size() == 0U) return;
   mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+  int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
                                          &new_lshape, &new_rshape, &new_oshape);
   if (!ndim) {
     ElemwiseBinaryOp::ComputeLogic<xpu, OP>(attrs, ctx, inputs, req, outputs);
   } else {
     if (req[0] == kNullOp) return;
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-      BROADCAST_NDIM_SWITCH(ndim, NDim, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(lhs.type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(rhs.type_flag_, EType, {
+        BROADCAST_NDIM_SWITCH(ndim, NDim, {
         mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
         mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
         mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
         mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
         template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
-        inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
-        outputs[0].dptr<bool>());
+        lhs.dptr<DType>(), rhs.dptr<EType>(),
+        out.dptr<bool>());
+        });
       });
     });
   }
 }
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 01dca2e..c080570 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -620,14 +620,16 @@ template<typename xpu, typename OP>
     CHECK_EQ(outputs.size(), 1U);
     MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
       MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
+        MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, EType, {
         const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
         + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
         if (size != 0) {
           Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
           outputs[0].dptr<bool>(), inputs[0].dptr<DType>(),
-          inputs[1].dptr<DType>());
+          inputs[1].dptr<EType>());
         }
+        });
       });
     });
   }
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 214720e..572dd34 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2781,6 +2781,15 @@ def test_np_mixed_precision_binary_funcs():
         'mod': (1.0, 5.0, None, None),
         'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
                             lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
+        'equal': (0.0, 2.0, None, None),
+        'not_equal': (0.0, 2.0, None, None),
+        'greater': (0.0, 2.0, None, None),
+        'less': (0.0, 2.0, None, None),
+        'greater_equal': (0.0, 2.0, None, None),
+        'less_equal': (0.0, 2.0, None, None),
+        'logical_and': (0.0, 2.0, None, None),
+        'logical_or': (0.0, 2.0, None, None),
+        'logical_xor': (0.0, 2.0, None, None),
     }
 
     shape_pairs = [((3, 2), (3, 2)),
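
A minimal usage sketch (not part of the commit) of the behaviour this change targets, assuming the mxnet.numpy front end: with the per-operand type switches above, binary logic operators accept operands of different dtypes directly and return a boolean ndarray, rather than requiring both inputs to share a dtype. The array values, shapes, and printed results below are illustrative only.

    # Illustrative sketch, mirroring the mixed-precision cases added to
    # test_np_mixed_precision_binary_funcs; values and shapes are arbitrary.
    import mxnet as mx

    a = mx.np.array([1.0, 2.0, 3.0], dtype='float32')
    b = mx.np.array([1, 2, 4], dtype='int64')

    out = mx.np.equal(a, b)   # mixed float32/int64 operands
    print(out.dtype)          # expected: bool
    print(out)                # expected: [ True  True False]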