This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new c59a325  fix mixed type binary logic operators (#18427)
c59a325 is described below

commit c59a3255346ebe9bc0729c5a702fc99624ed2374
Author: Yijun Chen <chenyijun0...@gmail.com>
AuthorDate: Tue Jun 2 15:31:16 2020 +0800

    fix mixed type binary logic operators (#18427)
---
 src/operator/mshadow_op.h                          |  6 ++--
 src/operator/mxnet_op.h                            |  7 +++++
 .../numpy/np_elemwise_broadcast_logic_op.cc        |  2 --
 src/operator/tensor/elemwise_binary_broadcast_op.h | 34 ++++++++++++++++++----
 src/operator/tensor/elemwise_binary_op.h           |  4 ++-
 tests/python/unittest/test_numpy_op.py             |  9 ++++++
 6 files changed, 52 insertions(+), 10 deletions(-)

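The user-visible effect of this change: the NumPy-style binary logic operators (equal, less, logical_and, and friends) now accept operands of differing dtypes and always return a bool tensor. A rough usage sketch of the fixed behavior (values and dtypes here are illustrative, not taken from the commit):

    import mxnet.numpy as np
    from mxnet import npx
    npx.set_np()  # enable NumPy semantics

    # Before this fix, type inference forced both inputs of a logic op
    # to share a dtype; now each operand keeps its own dtype.
    a = np.array([1, 2, 3], dtype='int32')
    b = np.array([1.0, 0.0, 3.0], dtype='float32')
    print(np.equal(a, b))        # [ True False  True]
    print(np.equal(a, b).dtype)  # bool
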
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 4cbb17d..9069af9 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -114,8 +114,10 @@ using std::is_integral;
 
 #define MXNET_BINARY_LOGIC_OP_NC(name, expr) \
   struct name : public mxnet_op::tunable  { \
-    template<typename DType> \
-    MSHADOW_XINLINE static bool Map(DType a, DType b) { \
+    template<typename DType, typename EType> \
+    MSHADOW_XINLINE static bool Map(DType lhs, EType rhs) { \
+      double a = static_cast<double>(lhs); \
+      double b = static_cast<double>(rhs); \
       return (expr); \
     } \
   }
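
With this macro change, every generated logic op compares its operands after promoting both to double, so a (float16, int64) pair goes through the same comparison as a matched pair. A minimal Python model of the new Map (hypothetical names, for illustration only):

    def logic_map(expr, lhs, rhs):
        # Mirror of the generated Map: both operands are widened to
        # double (Python float) before the comparison expression runs.
        a = float(lhs)
        b = float(rhs)
        return expr(a, b)

    print(logic_map(lambda a, b: a == b, 3, 3.0))  # True
    print(logic_map(lambda a, b: a < b, 2.5, 2))   # False
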
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index 3f1c804..bc8c0af 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -860,6 +860,13 @@ struct op_with_req {
     KERNEL_ASSIGN(out[i], req, OP::Map(in[i], value));
   }
 
+  /*! \brief inputs are two tensors with different types and a boolean output tensor */
+  template<typename LType, typename RType,
+           typename std::enable_if<!std::is_same<LType, RType>::value, int>::type = 0>
+  MSHADOW_XINLINE static void Map(index_t i, bool *out, const LType *lhs, const RType *rhs) {
+    KERNEL_ASSIGN(out[i], req, OP::Map(lhs[i], rhs[i]));
+  }
+
 #ifndef _WIN32
   /*! \brief inputs are two tensors with a half_t output tensor */
   template<typename DType,
diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
index 8c8320d..b191553 100644
--- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -64,8 +64,6 @@ bool NumpyBinaryLogicOpType(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 2U);
   CHECK_EQ(out_attrs->size(), 1U);
   if (in_attrs->at(0) == -1 && in_attrs->at(1) == -1) return false;
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(1));
-  TYPE_ASSIGN_CHECK(*in_attrs, 1, in_attrs->at(0));
   TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
   return true;
 }
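
Dropping the two TYPE_ASSIGN_CHECK lines is what actually permits mixed inputs: inference no longer forces in_attrs[0] and in_attrs[1] to agree, and only the output is pinned to kBool. A small Python model of the remaining inference rule (hypothetical helper; -1 means "unknown dtype", as in the C++ code):

    def numpy_binary_logic_op_type(in_types):
        # Both inputs still unknown: inference cannot proceed yet.
        if in_types[0] == -1 and in_types[1] == -1:
            return None
        # Inputs keep whatever dtypes they arrived with;
        # the output dtype is always bool.
        return in_types, ['bool']

    print(numpy_binary_logic_op_type(['int32', 'float32']))
    # (['int32', 'float32'], ['bool'])
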
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index ffd0f12..6f6711e 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -209,6 +209,25 @@ struct binary_broadcast_kernel {
   }
 
   /*! \brief Map function for binary_broadcast_kernel */
+  template<typename LType, typename RType, typename OType>
+  MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
+                                  const Shape <ndim> &lstride, const Shape <ndim> &rstride,
+                                  const Shape <ndim> &oshape, LType *lhs, RType *rhs,
+                                  OType *out) {
+    Shape <ndim> coord = unravel(base, oshape);
+    auto lidx = static_cast<index_t>(dot(coord, lstride));
+    auto ridx = static_cast<index_t>(dot(coord, rstride));
+    KERNEL_ASSIGN(out[base], req, OP::Map(lhs[lidx], rhs[ridx]));
+    // starts from 1 to avoid extra inc at end of loop
+    for (index_t i = 1; i < length; ++i) {
+      inc(&coord, oshape, &lidx, lstride, &ridx, rstride);
+      // When tuning, don't actually run the op, since it's not going to be tuned against
+      // the actual op we'll eventually be using
+      KERNEL_ASSIGN(out[base + i], req, OP::Map(lhs[lidx], rhs[ridx]));
+    }
+  }
+
+  /*! \brief Map function for binary_broadcast_kernel */
   template<typename IType, typename DType>
   MSHADOW_XINLINE static void Map(index_t base, index_t length, OpReqType req,
                                   const Shape <ndim> &lstride, const Shape <ndim> &rstride,
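
The new LType/RType/OType Map overload in this hunk walks the flattened output and maps each output coordinate back into the two (possibly broadcast) inputs through per-input strides; an axis of size 1 gets stride 0, so one input element is reused along that axis. A self-contained Python model of the unravel/dot indexing (the kernel's inc-based walk is an incremental optimization of the same computation):

    import numpy as _np

    def broadcast_indices(base, oshape, lstride, rstride):
        # Flat output index -> multi-dimensional coordinate, then project
        # the coordinate through each operand's strides to obtain flat
        # indices into lhs and rhs.
        coord = _np.unravel_index(base, oshape)
        return int(_np.dot(coord, lstride)), int(_np.dot(coord, rstride))

    # lhs shape (3, 1) broadcast against rhs shape (1, 4) -> output (3, 4);
    # stride 0 on a broadcast axis repeats the same element.
    print(broadcast_indices(5, (3, 4), lstride=(1, 0), rstride=(0, 1)))
    # (1, 1): output element (1, 1) reads lhs[1] and rhs[1]
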
@@ -430,23 +449,28 @@ void BinaryBroadcastComputeLogic(const nnvm::NodeAttrs& attrs,
                                  const std::vector<TBlob>& outputs) {
   if (outputs[0].shape_.Size() == 0U) return;
   mxnet::TShape new_lshape, new_rshape, new_oshape;
-  int ndim = BinaryBroadcastShapeCompact(inputs[0].shape_, inputs[1].shape_, outputs[0].shape_,
+  const TBlob& lhs = inputs[0];
+  const TBlob& rhs = inputs[1];
+  const TBlob& out = outputs[0];
+  int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
                                          &new_lshape, &new_rshape, &new_oshape);
   if (!ndim) {
     ElemwiseBinaryOp::ComputeLogic<xpu, OP>(attrs, ctx, inputs, req, outputs);
   } else {
     if (req[0] == kNullOp) return;
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
-    MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
-        BROADCAST_NDIM_SWITCH(ndim, NDim, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(lhs.type_flag_, DType, {
+      MSHADOW_TYPE_SWITCH_WITH_BOOL(rhs.type_flag_, EType, {
+         BROADCAST_NDIM_SWITCH(ndim, NDim, {
           mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
          mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
          mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
           mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
          template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
-                            inputs[0].dptr<DType>(), inputs[1].dptr<DType>(),
-                            outputs[0].dptr<bool>());
+                            lhs.dptr<DType>(), rhs.dptr<EType>(),
+                            out.dptr<bool>());
         });
+      });
     });
   }
 }
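
One consequence of nesting MSHADOW_TYPE_SWITCH_WITH_BOOL over both lhs and rhs: the kernel body is instantiated once per (DType, EType) pair rather than once per dtype, so the number of compiled kernels per operator grows quadratically with the dtype list. A rough count (the exact dtype set covered by the macro is an assumption here):

    # Dtypes plausibly covered by MSHADOW_TYPE_SWITCH_WITH_BOOL;
    # the macro's real list may differ slightly.
    dtypes = ['float16', 'float32', 'float64',
              'uint8', 'int8', 'int32', 'int64', 'bool']
    pairs = [(l, r) for l in dtypes for r in dtypes]
    print(len(dtypes), len(pairs))  # 8 64
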
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index 01dca2e..c080570 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -620,14 +620,16 @@ template<typename xpu, typename OP>
     CHECK_EQ(outputs.size(), 1U);
     MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
       MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[0].type_flag_, DType, {
+        MSHADOW_TYPE_SWITCH_WITH_BOOL(inputs[1].type_flag_, EType, {
          const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
           + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
           if (size != 0) {
             Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
                                                                outputs[0].dptr<bool>(),
                                                                inputs[0].dptr<DType>(),
-                                                                inputs[1].dptr<DType>());
+                                                                inputs[1].dptr<EType>());
           }
+        });
       });
     });
   }
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 214720e..572dd34 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2781,6 +2781,15 @@ def test_np_mixed_precision_binary_funcs():
         'mod': (1.0, 5.0, None, None),
         'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
                             lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
+        'equal': (0.0, 2.0, None, None),
+        'not_equal': (0.0, 2.0, None, None),
+        'greater': (0.0, 2.0, None, None),
+        'less': (0.0, 2.0, None, None),
+        'greater_equal': (0.0, 2.0, None, None),
+        'less_equal': (0.0, 2.0, None, None),
+        'logical_and': (0.0, 2.0, None, None),
+        'logical_or': (0.0, 2.0, None, None),
+        'logical_xor': (0.0, 2.0, None, None),
     }
 
     shape_pairs = [((3, 2), (3, 2)),

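Each dict entry in the updated test maps an operator name to (low, high, grad_x1, grad_x2): the sampling range for the inputs plus reference gradient functions, as the 'power' entry illustrates. The new logic entries pass None for both gradients, since a boolean output has no meaningful gradient. A sketch of the kind of check the test performs for each op and mixed dtype pair (harness details are assumptions, not copied from the test file):

    import numpy as onp
    import mxnet.numpy as np
    from mxnet import npx
    npx.set_np()

    def check_mixed_logic(name, ltype='float32', rtype='int32', shape=(3, 2)):
        # Compare MXNet's mixed-dtype result against NumPy's reference.
        x1 = onp.random.uniform(0.0, 2.0, shape).astype(ltype)
        x2 = onp.random.uniform(0.0, 2.0, shape).astype(rtype)
        expected = getattr(onp, name)(x1, x2)
        actual = getattr(np, name)(np.array(x1), np.array(x2))
        assert (actual.asnumpy() == expected).all()

    for op in ['equal', 'not_equal', 'greater', 'less',
               'greater_equal', 'less_equal',
               'logical_and', 'logical_or', 'logical_xor']:
        check_mixed_logic(op)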