This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 9feecce  [MXNET-399] Elemwise_mul between dense and csr on CPU & GPU (#10894)
9feecce is described below

commit 9feeccecb4ab64461cfae0bd4e75dd4bcbd7c9d5
Author: Hao Jin <haoj...@users.noreply.github.com>
AuthorDate: Wed May 30 17:33:42 2018 -0700

    [MXNET-399] Elemwise_mul between dense and csr on CPU & GPU (#10894)
    
    * support elemwise_mul between dns and csr
    
    * address reviews and support for backward when ograd is dns
---
 src/operator/tensor/elemwise_binary_op-inl.h    |  85 +++++++++++++++++
 src/operator/tensor/elemwise_binary_op.cc       |  21 ++++
 src/operator/tensor/elemwise_binary_op.h        | 121 +++++++++++++++++-------
 src/operator/tensor/elemwise_binary_op_basic.cu |   4 +-
 tests/python/unittest/test_sparse_operator.py   |  14 ++-
 5 files changed, 210 insertions(+), 35 deletions(-)
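
The user-visible effect: elemwise_mul now accepts one dense and one csr
operand and produces a csr result on both CPU and GPU. A minimal sketch of
the new behavior in Python (shapes chosen for illustration; tostype() and
mx.nd.sparse.elemwise_mul are the existing sparse NDArray API):

    import mxnet as mx

    dns = mx.nd.random.uniform(shape=(3, 4))
    csr = mx.nd.random.uniform(shape=(3, 4)).tostype('csr')

    # dense * csr (in either argument order) keeps the result sparse
    out = mx.nd.sparse.elemwise_mul(dns, csr)
    assert out.stype == 'csr'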

diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h
index c74f1f9..911c369 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -496,6 +496,91 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<xpu> *s,
 }
 
 /*!
+ * \brief Kernel for performing elemwise op between dense and csr matrix
+ * \param i            global thread id
+ * \param req          type of request
+ * \param out          output array
+ * \param dns_data     data array of dense input
+ * \param csr_data     data array of csr input
+ * \param csr_indices  indices array of csr input
+ * \param csr_indptr   indptr array of csr input
+ * \param num_rows     number of rows of both inputs
+ * \param num_cols     number of columns of both inputs
+ */
+template<int req, typename OP, bool reverse>
+struct ElemwiseDnsCsrCsrKernel {
+  template<typename DType, typename IType, typename CType>
+  MSHADOW_XINLINE static void Map(int i, DType* out, DType* dns_data,
+                                  const DType* csr_data, const IType* csr_indices,
+                                  const CType* csr_indptr, const nnvm::dim_t num_rows,
+                                  const nnvm::dim_t num_cols) {
+    if (i < num_rows) {
+      for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) {
+        KERNEL_ASSIGN(out[j], req, reverse ?
+                                   OP::Map(dns_data[i * num_cols + csr_indices[j]], csr_data[j]) :
+                                   OP::Map(csr_data[j], dns_data[i * num_cols + csr_indices[j]]));
+      }
+    }
+  }
+};
+
+/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+template<typename xpu, typename OP>
+void ElemwiseBinaryOp::DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
+                                   const OpContext &ctx,
+                                   const NDArray &dns,
+                                   const NDArray &csr,
+                                   const OpReqType req,
+                                   const NDArray &output,
+                                   const bool reverse) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  using namespace csr;
+  CHECK_EQ(dns.storage_type(), kDefaultStorage);
+  CHECK_EQ(csr.storage_type(), kCSRStorage);
+  CHECK_EQ(req, kWriteTo) << "elemwise(dns, csr) = csr only supports kWriteTo";
+  if (req == kNullOp) return;
+  const bool supported_op = std::is_same<OP, mshadow_op::mul>::value;
+  CHECK(supported_op == true) << "elemwise(dns, csr) = csr only supports mul";
+  const nnvm::dim_t num_csr_rows = csr.shape()[0];
+  const nnvm::dim_t num_csr_cols = csr.shape()[1];
+  const nnvm::dim_t nnz = csr.storage_shape()[0];
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+
+  output.CheckAndAlloc({Shape1(num_csr_rows + 1), Shape1(nnz)});
+  if (csr.storage_initialized()) {
+    TBlob csr_data = csr.data();
+    TBlob csr_indices = csr.aux_data(kIdx);
+    TBlob csr_indptr = csr.aux_data(kIndPtr);
+    MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
+      MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
+        MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
+          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+            if (reverse) {
+              Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, true>, xpu>::Launch(
+                s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
+                csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
+                num_csr_rows, num_csr_cols);
+            } else {
+              Kernel<ElemwiseDnsCsrCsrKernel<Req, OP, false>, xpu>::Launch(
+                s, num_csr_rows, output.data().dptr<DType>(), dns.data().dptr<DType>(),
+                csr_data.dptr<DType>(), csr_indices.dptr<IType>(), csr_indptr.dptr<CType>(),
+                num_csr_rows, num_csr_cols);
+            }
+            Copy(output.aux_data(kIdx).FlatTo1D<xpu, IType>(),
+                 csr.aux_data(kIdx).FlatTo1D<xpu, IType>(), s);
+            Copy(output.aux_data(kIndPtr).FlatTo1D<xpu, CType>(),
+                 csr.aux_data(kIndPtr).FlatTo1D<xpu, CType>(), s);
+          });
+        });
+      });
+    });
+  } else {
+    FillZerosCsrImpl(s, output);
+  }
+}
+
+/*!
  * \brief Kernel for performing elemwise op between dense and rsp tensor
  * \param i            global thread id
  * \param req          type of request
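
For reference, the per-row traversal that ElemwiseDnsCsrCsrKernel performs can
be sketched in plain Python (indexing only, over the same flat arrays the
kernel receives; a sketch, not the shipped code):

    # One logical thread per csr row i: only the stored entries of that row
    # are visited, so the output inherits the csr operand's sparsity pattern.
    def dns_csr_mul_row(i, out, dns_data, csr_data, csr_indices, csr_indptr, num_cols):
        for j in range(csr_indptr[i], csr_indptr[i + 1]):
            # the dense operand is row-major: entry (i, csr_indices[j])
            # flattens to i * num_cols + csr_indices[j]
            out[j] = csr_data[j] * dns_data[i * num_cols + csr_indices[j]]
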
diff --git a/src/operator/tensor/elemwise_binary_op.cc b/src/operator/tensor/elemwise_binary_op.cc
index e8ba2fa..9ccbacc 100644
--- a/src/operator/tensor/elemwise_binary_op.cc
+++ b/src/operator/tensor/elemwise_binary_op.cc
@@ -63,6 +63,11 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
   const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask;
   const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
                            DispatchMode::kFComputeEx;
+  const int ograd_stype = in_attrs->at(0);
+  const int lhs_stype = in_attrs->at(1);
+  const int rhs_stype = in_attrs->at(2);
+  int& lhs_grad_stype = out_attrs->at(0);
+  int& rhs_grad_stype = out_attrs->at(1);
   if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
     dispatched = storage_type_assign(out_attrs, kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
@@ -74,6 +79,22 @@ bool ElemwiseBinaryOp::BackwardUseInStorageType(const nnvm::NodeAttrs& attrs,
                                        dispatch_mode, dispatch_ex);
     }
   }
+  if (!dispatched && ograd_stype == kDefaultStorage &&
+      ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+       (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+    const bool reverse = (lhs_stype == kCSRStorage);
+    if (reverse &&
+        type_assign(&lhs_grad_stype, kDefaultStorage) &&
+        type_assign(&rhs_grad_stype, kCSRStorage)) {
+      DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+      dispatched = true;
+    } else if (!reverse &&
+               type_assign(&lhs_grad_stype, kCSRStorage) &&
+               type_assign(&rhs_grad_stype, kDefaultStorage)) {
+      DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
+      dispatched = true;
+    }
+  }
   if (!dispatched) {
     dispatched = dispatch_fallback(out_attrs, dispatch_mode);
   }
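
The storage types this branch assigns follow from d(lhs) = ograd * rhs and
d(rhs) = ograd * lhs: with a dense ograd, each product takes the storage type
of its other factor, so each input's gradient gets the opposite input's stype.
A sketch of the rule (hypothetical helper, mirroring the branch above):

    def backward_grad_stypes(ograd_stype, lhs_stype, rhs_stype):
        # the new branch: dense ograd, one csr input, one dense input
        if (ograd_stype == 'default' and
                {lhs_stype, rhs_stype} == {'default', 'csr'}):
            # lhs_grad = ograd * rhs -> rhs's stype
            # rhs_grad = ograd * lhs -> lhs's stype
            return rhs_stype, lhs_stype
        return 'default', 'default'  # other cases are dispatched elsewhere
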
diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h
index a5b73da..ad4b3e7 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -165,12 +165,11 @@ class ElemwiseBinaryOp : public OpBase {
     typename xpu,
     typename LOP,
     typename ROP,
-    typename DType,
     bool in0_ok_dense = false,
     bool in1_ok_dense = false,
     bool in2_ok_dense = false,
     typename BackupCompute>
-  static inline void BackwardUseInEx_(const nnvm::NodeAttrs &attrs,
+  static inline void RspRspOpBackward(const nnvm::NodeAttrs &attrs,
                                       const OpContext &ctx,
                                       const std::vector<NDArray> &inputs,
                                       const std::vector<OpReqType> &req,
@@ -200,6 +199,33 @@ class ElemwiseBinaryOp : public OpBase {
     }
   }
 
+  template<typename xpu, typename LOP, typename ROP>
+  static inline void DnsCsrCsrOpBackward(const nnvm::NodeAttrs &attrs,
+                                         const OpContext &ctx,
+                                         const std::vector<NDArray> &inputs,
+                                         const std::vector<OpReqType> &req,
+                                         const std::vector<NDArray> &outputs) {
+    const bool supported_ops = std::is_same<mshadow_op::right, LOP>::value &&
+                                std::is_same<mshadow_op::left, ROP>::value;
+    CHECK(supported_ops)
+      << "Only backward for mul is supported (LOP should be right, ROP should 
be left)";
+    const NDArray& out_grad = inputs[0];
+    const NDArray& lhs_in = inputs[1];
+    const NDArray& rhs_in = inputs[2];
+    const NDArray& lhs_grad = outputs[0];
+    const NDArray& rhs_grad = outputs[1];
+    const bool reverse = (outputs[0].storage_type() == kCSRStorage);
+    if (reverse) {
+      DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, rhs_in, req[0], lhs_grad, false);
+      Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), lhs_in.data()}, {req[1]},
+                                    {rhs_grad.data()});
+    } else {
+      DnsCsrCsrOp<xpu, mshadow_op::mul>(attrs, ctx, out_grad, lhs_in, req[1], rhs_grad, false);
+      Compute<xpu, mshadow_op::mul>(attrs, ctx, {out_grad.data(), rhs_in.data()}, {req[0]},
+                                    {lhs_grad.data()});
+    }
+  }
+
  public:
  /*! \brief Binary op handling for lhs/rhs: RspDns, RspRsp, DnsRsp, or RspRsp->Dns result */
   template<typename OP>
@@ -232,44 +258,54 @@ class ElemwiseBinaryOp : public OpBase {
   /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
   template<typename OP>
   static void CsrCsrOp(mshadow::Stream<cpu> *s,
-                              const nnvm::NodeAttrs &attrs,
-                              const OpContext &ctx,
-                              const NDArray &lhs,
-                              const NDArray &rhs,
-                              OpReqType req,
-                              const NDArray &output);
+                       const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const NDArray &lhs,
+                       const NDArray &rhs,
+                       OpReqType req,
+                       const NDArray &output);
 
   /*! \brief CSR -op- CSR binary operator for non-canonical NDArray */
   template<typename OP>
   static void CsrCsrOp(mshadow::Stream<gpu> *s,
-                              const nnvm::NodeAttrs &attrs,
-                              const OpContext &ctx,
-                              const NDArray &lhs,
-                              const NDArray &rhs,
-                              OpReqType req,
-                              const NDArray &output);
+                       const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const NDArray &lhs,
+                       const NDArray &rhs,
+                       OpReqType req,
+                       const NDArray &output);
 
   /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
   template<typename xpu, typename OP>
   static void DnsCsrDnsOp(mshadow::Stream<xpu> *s,
-                                 const nnvm::NodeAttrs &attrs,
-                                 const OpContext &ctx,
-                                 const NDArray &lhs,
-                                 const NDArray &rhs,
-                                 OpReqType req,
-                                 const NDArray &output,
-                                 const bool reverse);
+                          const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &lhs,
+                          const NDArray &rhs,
+                          OpReqType req,
+                          const NDArray &output,
+                          const bool reverse);
+
+  /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+  template<typename xpu, typename OP>
+  static void DnsCsrCsrOp(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &lhs,
+                          const NDArray &rhs,
+                          OpReqType req,
+                          const NDArray &output,
+                          const bool reverse);
 
   /*! \brief DNS -op- RSP binary operator for non-canonical NDArray */
   template<typename xpu, typename OP>
   static void DnsRspDnsOp(mshadow::Stream<xpu> *s,
-                                 const nnvm::NodeAttrs &attrs,
-                                 const OpContext &ctx,
-                                 const NDArray &lhs,
-                                 const NDArray &rhs,
-                                 OpReqType req,
-                                 const NDArray &output,
-                                 const bool reverse);
+                          const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &lhs,
+                          const NDArray &rhs,
+                          OpReqType req,
+                          const NDArray &output,
+                          const bool reverse);
 
  public:
   /*!
@@ -336,6 +372,14 @@ class ElemwiseBinaryOp : public OpBase {
         dispatched = storage_type_assign(&out_stype, kRowSparseStorage,
                                          dispatch_mode, dispatch_ex);
     }
+    if (!dispatched &&
+        ((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+         (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage))) {
+        // csr, dns -> csr
+        // dns, csr -> csr
+        dispatched = storage_type_assign(&out_stype, kCSRStorage,
+                                         dispatch_mode, DispatchMode::kFComputeEx);
+    }
     if (!dispatched) {
       dispatched = dispatch_fallback(out_attrs, dispatch_mode);
     }
@@ -540,6 +584,14 @@ class ElemwiseBinaryOp : public OpBase {
             req[0], outputs[0], lhs_may_be_dense, rhs_may_be_dense, false, false);
     } else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) {
       ComputeEx<xpu, OP>(attrs, ctx, inputs, req, outputs);
+    } else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+                (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
+                out_stype == kCSRStorage) {
+      const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : inputs[1];
+      const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
+      const bool reverse = (lhs_stype == kCSRStorage);
+
+      DnsCsrCsrOp<xpu, OP>(attrs, ctx, dns, csr, req[0], outputs[0], reverse);
     } else {
       LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
     }
@@ -635,16 +687,21 @@ class ElemwiseBinaryOp : public OpBase {
     using namespace common;
     CHECK_EQ(inputs.size(), 3U);
     CHECK_EQ(outputs.size(), 2U);  // lhs input grad, rhs input grad
+    const auto out_grad_stype = inputs[0].storage_type();
     const auto lhs_grad_stype = outputs[0].storage_type();
     const auto rhs_grad_stype = outputs[1].storage_type();
     if (ContainsOnlyStorage(inputs, kRowSparseStorage) &&
        (lhs_grad_stype == kDefaultStorage || lhs_grad_stype == kRowSparseStorage) &&
        (rhs_grad_stype == kDefaultStorage || rhs_grad_stype == kRowSparseStorage)) {
       // rsp, rsp, rsp -> [dns, rsp], [dns, rsp]
-      MSHADOW_TYPE_SWITCH(outputs[0].dtype(), DType, {
-        BackwardUseInEx_<xpu, LOP, ROP, DType, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
-          attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
-      });
+      RspRspOpBackward<xpu, LOP, ROP, in0_ok_dense, in1_ok_dense, in2_ok_dense>(
+        attrs, ctx, inputs, req, outputs, BackwardUseIn<xpu, LOP, ROP>);
+    }
+    if (((lhs_grad_stype == kDefaultStorage && rhs_grad_stype == kCSRStorage) ||
+         (lhs_grad_stype == kCSRStorage && rhs_grad_stype == kDefaultStorage)) &&
+        out_grad_stype == kDefaultStorage) {
+      // dns, csr, dns -> [csr, dns] / csr, dns, dns -> [dns, csr]
+      DnsCsrCsrOpBackward<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
     }
   }
 };  // class ElemwiseBinaryOp
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu
index 9c1fd0e..5cdd894 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -51,7 +51,9 @@ NNVM_REGISTER_OP(_backward_sub)
                     mshadow_op::negation>);
 
 NNVM_REGISTER_OP(elemwise_mul)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, 
op::mshadow_op::mul>);
+.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, 
op::mshadow_op::mul>)
+.set_attr<FComputeEx>("FComputeEx<gpu>",
+  ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);
 
 NNVM_REGISTER_OP(_backward_mul)
 .set_attr<FCompute>("FCompute<gpu>",
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 226db70..b2ff0fe 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -329,9 +329,19 @@ def test_elemwise_binary_ops():
             return 'row_sparse'
         elif lstype == 'row_sparse' and rstype == 'default':
             return 'row_sparse'
+        elif lstype == 'default' and rstype == 'csr':
+            return 'csr'
+        elif lstype == 'csr' and rstype == 'default':
+            return 'csr'
         else:
             return 'default'
 
+    def elemwise_mul_lhs_grad_stype(lstype, rstype):
+        return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), rstype)
+
+    def elemwise_mul_rhs_grad_stype(lstype, rstype):
+        return elemwise_mul_stype(elemwise_mul_stype(lstype, rstype), lstype)
+
     def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape,
                                   lhs_grad_stype=None, rhs_grad_stype=None,
                                   lhs_density=.5, rhs_density=.5,
@@ -378,8 +388,8 @@ def test_elemwise_binary_ops():
                                 lambda l, r: mx.sym.sparse.elemwise_mul(l, r),
                                 lambda l, r: l * r,
                                 lambda outg, l, r: (outg * r, outg * l),
-                                elemwise_mul_stype(lhs_stype, rhs_stype),
-                                elemwise_mul_stype(lhs_stype, rhs_stype),
+                                elemwise_mul_lhs_grad_stype(lhs_stype, rhs_stype),
+                                elemwise_mul_rhs_grad_stype(lhs_stype, rhs_stype),
                                 expected_result_storage_type=elemwise_mul_stype(lhs_stype, rhs_stype),
                                 ograd_density=ograd_density,
                                 force_lr_overlap=force_lr_overlap,

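The two grad-stype helpers compose the forward rule: the ograd seen by the
test has the output's stype, elemwise_mul_stype(lstype, rstype), and each
gradient is that ograd multiplied elementwise by the opposite input. A worked
case (assuming the helper's earlier, unshown branch returns the common stype
when both arguments match):

    # lhs='default', rhs='csr'
    # out/ograd: elemwise_mul_stype('default', 'csr')  -> 'csr'
    # lhs_grad:  elemwise_mul_stype('csr', 'csr')      -> 'csr'  (ograd * rhs)
    # rhs_grad:  elemwise_mul_stype('csr', 'default')  -> 'csr'  (ograd * lhs)
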
-- 
To stop receiving notification emails like this one, please contact hai...@apache.org.
