eric-haibin-lin closed pull request #10550: [MXNET-320] Support 
elemwise_add/sub between dense and csr tensors
URL: https://github.com/apache/incubator-mxnet/pull/10550
 
 
   

This is a PR merged from a forked repository. Because GitHub hides the
original diff once a foreign (fork-based) pull request is merged, the diff
is reproduced below for the sake of provenance:

diff --git a/src/operator/tensor/elemwise_binary_op-inl.h 
b/src/operator/tensor/elemwise_binary_op-inl.h
index 15b1f0e286e..54b7aa60a64 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -374,6 +374,82 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s,
   }
 }
 
+/*!
+ * \brief Kernel for performing elemwise op between dense and csr matrix
+ * \param i            global thread id
+ * \param req          type of request
+ * \param out          output array
+ * \param dns_data     data array of dense input
+ * \param csr_data     data array of csr input
+ * \param csr_indices  indices array of csr input
+ * \param csr_indptr   indptr array of csr input
+ * \param num_rows     number of rows of both inputs
+ * \param num_cols     number of columns of both inputs
+ */
+template<typename OP>
+struct ElemwiseDnsCsrDnsKernel {
+  template<typename DType, typename IType, typename CType>
+  static void inline Map(int i, OpReqType req, DType* out, DType* dns_data, 
const DType* csr_data,
+                         const IType* csr_indices, const CType* csr_indptr,
+                         const nnvm::dim_t num_rows, const nnvm::dim_t 
num_cols) {
+    if (i < num_rows) {
+      for (int j = csr_indptr[i]; j < csr_indptr[i+1]; ++j) {
+        KERNEL_ASSIGN(out[i * num_cols + csr_indices[j]], req,
+                      OP::Map(dns_data[i * num_cols + csr_indices[j]], 
csr_data[j]));
+      }
+    }
+  }
+};
+
+/*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+template<typename OP>
+void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<cpu> *s,
+                                   const nnvm::NodeAttrs &attrs,
+                                   const OpContext &ctx,
+                                   const NDArray &dns,
+                                   const NDArray &csr,
+                                   const OpReqType req,
+                                   const NDArray &output,
+                                   const bool reverse) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  CHECK_EQ(dns.storage_type(), kDefaultStorage);
+  CHECK_EQ(csr.storage_type(), kCSRStorage);
+  CHECK(req != kAddTo);
+  CHECK(req != kNullOp);
+  const bool supported_op = std::is_same<OP, mshadow_op::minus>::value ||
+                            std::is_same<OP, mshadow_op::plus>::value;
+  CHECK(supported_op == true);
+  const nnvm::dim_t num_csr_rows = csr.shape()[0];
+  const nnvm::dim_t num_csr_cols = csr.shape()[1];
+  TBlob csr_data = csr.data();
+  TBlob csr_indices = csr.aux_data(csr::kIdx);
+  TBlob csr_indptr = csr.aux_data(csr::kIndPtr);
+  MSHADOW_SGL_DBL_TYPE_SWITCH(csr_data.type_flag_, DType, {
+    MSHADOW_IDX_TYPE_SWITCH(csr_indices.type_flag_, IType, {
+      MSHADOW_IDX_TYPE_SWITCH(csr_indptr.type_flag_, CType, {
+        MXNET_ASSIGN_REQ_SWITCH(req, Req, {
+          if (reverse && std::is_same<OP, mshadow_op::minus>::value) {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::negation, Req>, 
cpu>::Launch(
+              s, output.data().Size(), output.data().dptr<DType>(), 
dns.data().dptr<DType>());
+            mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<mshadow_op::plus>, 
cpu>::Launch(
+              s, num_csr_rows, Req, output.data().dptr<DType>(),
+              output.data().dptr<DType>(), csr_data.dptr<DType>(), 
csr_indices.dptr<IType>(),
+              csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols);
+          } else {
+            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, 
cpu>::Launch(
+              s, output.data().Size(), output.data().dptr<DType>(), 
dns.data().dptr<DType>());
+            mxnet_op::Kernel<ElemwiseDnsCsrDnsKernel<OP>, cpu>::Launch(
+              s, num_csr_rows, Req, output.data().dptr<DType>(),
+              output.data().dptr<DType>(), csr_data.dptr<DType>(), 
csr_indices.dptr<IType>(),
+              csr_indptr.dptr<CType>(), num_csr_rows, num_csr_cols);
+          }
+        });
+      });
+    });
+  });
+}
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/tensor/elemwise_binary_op.h 
b/src/operator/tensor/elemwise_binary_op.h
index 9a151d38f81..b3b36ab83f4 100644
--- a/src/operator/tensor/elemwise_binary_op.h
+++ b/src/operator/tensor/elemwise_binary_op.h
@@ -233,6 +233,17 @@ class ElemwiseBinaryOp : public OpBase {
                               OpReqType req,
                               const NDArray &output);
 
+  /*! \brief DNS -op- CSR binary operator for non-canonical NDArray */
+  template<typename OP>
+  static inline void DnsCsrDnsOp(mshadow::Stream<cpu> *s,
+                                 const nnvm::NodeAttrs &attrs,
+                                 const OpContext &ctx,
+                                 const NDArray &lhs,
+                                 const NDArray &rhs,
+                                 OpReqType req,
+                                 const NDArray &output,
+                                 const bool reverse);
+
  public:
   /*!
    * \brief Rsp-op-Rsp operation which produces a dense result
@@ -305,6 +316,60 @@ class ElemwiseBinaryOp : public OpBase {
     return dispatched;
   }
 
+
+  /*!
+   * \brief Allow one of the inputs to be dense and produce a dense output,
+   *        for rsp inputs only support when both inputs are rsp type.
+   * \param attrs Attributes
+   * \param dev_mask Device mask
+   * \param dispatch_mode Dispatch Mode
+   * \param in_attrs Input storage attributes
+   * \param out_attrs Output storage attributes
+   * \return true if handled
+   */
+  template<bool cpu_only, bool rsp, bool csr>
+  static bool PreferDenseStorageType(const nnvm::NodeAttrs& attrs,
+                                     const int dev_mask,
+                                     DispatchMode* dispatch_mode,
+                                     std::vector<int> *in_attrs,
+                                     std::vector<int> *out_attrs) {
+    using namespace common;
+    CHECK_EQ(in_attrs->size(), 2);
+    CHECK_EQ(out_attrs->size(), 1);
+    const auto lhs_stype = (*in_attrs)[0];
+    const auto rhs_stype = (*in_attrs)[1];
+    bool dispatched = false;
+    const bool invalid_ctx = cpu_only && dev_mask != mshadow::cpu::kDevMask;
+    const auto dispatch_ex = invalid_ctx ? DispatchMode::kFComputeFallback :
+                                           DispatchMode::kFComputeEx;
+    if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
+      // dns, dns ... -> dns
+      dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                       dispatch_mode, DispatchMode::kFCompute);
+    }
+    if (!dispatched && rsp && ContainsOnlyStorage(*in_attrs, 
kRowSparseStorage)) {
+      // rsp, rsp, ... -> rsp
+      dispatched = storage_type_assign(out_attrs, kRowSparseStorage,
+                                       dispatch_mode, dispatch_ex);
+    }
+    if (!dispatched && csr && ContainsOnlyStorage(*in_attrs, kCSRStorage)) {
+      // csr, csr, ... -> csr
+      dispatched = storage_type_assign(out_attrs, kCSRStorage,
+                                       dispatch_mode, dispatch_ex);
+    }
+    if (!dispatched && ((lhs_stype == kDefaultStorage && rhs_stype == 
kCSRStorage) ||
+                        (lhs_stype == kCSRStorage && rhs_stype == 
kDefaultStorage))) {
+      // dense, csr -> dense / csr, dense -> dense
+      dispatched = storage_type_assign(out_attrs, kDefaultStorage,
+                                       dispatch_mode, dispatch_ex);
+    }
+    if (!dispatched) {
+      dispatch_fallback(out_attrs, dispatch_mode);
+    }
+    return true;
+  }
+
+
   /*!
    * \brief Backward pass computing input gradient using forward inputs
    * \param attrs Attributes
@@ -376,6 +441,7 @@ class ElemwiseBinaryOp : public OpBase {
     CHECK_EQ(outputs.size(), 1);
     if (req[0] == kNullOp) return;
     const auto lhs_stype = inputs[0].storage_type();
+    const auto rhs_stype = inputs[1].storage_type();
     const auto out_stype = outputs[0].storage_type();
     mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
     if ((common::ContainsOnlyStorage(inputs, kRowSparseStorage))
@@ -399,6 +465,14 @@ class ElemwiseBinaryOp : public OpBase {
           });
         });
       });
+    } else if (((lhs_stype == kCSRStorage && rhs_stype == kDefaultStorage) ||
+                (lhs_stype == kDefaultStorage && rhs_stype == kCSRStorage)) &&
+                out_stype == kDefaultStorage) {
+      const NDArray& dns = (lhs_stype == kDefaultStorage)? inputs[0] : 
inputs[1];
+      const NDArray& csr = (lhs_stype == kCSRStorage)? inputs[0] : inputs[1];
+      const bool reverse = (lhs_stype == kCSRStorage);
+
+      DnsCsrDnsOp<OP>(s, attrs, ctx, dns, csr, req[0], outputs[0], reverse);
     } else {
       LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
     }
@@ -590,6 +664,16 @@ class ElemwiseBinaryOp : public OpBase {
   .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, 
__kernel$>)              \
   .set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, 
__kernel$>)
 
/*! \brief Binary launch, with FComputeEx, using prefer-dense storage inference */
#define MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_PD(__name$, __kernel$)            \
  MXNET_OPERATOR_REGISTER_BINARY(__name$)                                                \
  .set_attr<FInferStorageType>("FInferStorageType",                                      \
    ElemwiseBinaryOp::PreferDenseStorageType<true, true, true>)                          \
  .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>)        \
  .set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, __kernel$>)  \
  .set_attr<FResourceRequest>("FResourceRequest",  /* temp space for sparse CSR */       \
    [](const NodeAttrs& attrs) {                                                         \
      return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};})
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc 
b/src/operator/tensor/elemwise_binary_op_basic.cc
index d73edc72352..3f2892fed17 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cc
+++ b/src/operator/tensor/elemwise_binary_op_basic.cc
@@ -67,8 +67,8 @@ static inline bool ElemwiseAddStorageType(const 
nnvm::NodeAttrs& attrs,
                                           std::vector<int> *out_attrs) {
   CHECK_EQ(in_attrs->size(), 2);
   CHECK_EQ(out_attrs->size(), 1);
-  bool ret = ElemwiseStorageType<2, 1, true, true, true>(attrs, dev_mask, 
dispatch_mode,
-                                                         in_attrs, out_attrs);
+  bool ret = ElemwiseBinaryOp::PreferDenseStorageType<true, true, true>(
+               attrs, dev_mask, dispatch_mode, in_attrs, out_attrs);
 #if MXNET_USE_MKLDNN == 1
   if (dev_mask == mshadow::cpu::kDevMask
       && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
@@ -94,6 +94,8 @@ The storage type of ``elemwise_add`` output depends on 
storage types of inputs
 
    - elemwise_add(row_sparse, row_sparse) = row_sparse
    - elemwise_add(csr, csr) = csr
+   - elemwise_add(default, csr) = default
+   - elemwise_add(csr, default) = default
    - otherwise, ``elemwise_add`` generates output with default storage
 
 )code")
@@ -157,7 +159,7 @@ NNVM_REGISTER_OP(_backward_add)
 .set_attr<FComputeEx>("FComputeEx<cpu>", _backward_ElemwiseAddEx)
 .set_attr<FInferStorageType>("FInferStorageType", 
ElemwiseAddBackwardStorageType);
 
-MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(elemwise_sub, 
op::mshadow_op::minus)
+MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_PD(elemwise_sub, 
op::mshadow_op::minus)
 MXNET_ADD_SPARSE_OP_ALIAS(elemwise_sub)
 .add_alias("_sub").add_alias("_minus").add_alias("_Minus")
 .describe(R"code(Subtracts arguments element-wise.
@@ -166,6 +168,8 @@ The storage type of ``elemwise_sub`` output depends on 
storage types of inputs
 
    - elemwise_sub(row_sparse, row_sparse) = row_sparse
    - elemwise_sub(csr, csr) = csr
+   - elemwise_sub(default, csr) = default
+   - elemwise_sub(csr, default) = default
    - otherwise, ``elemwise_sub`` generates output with default storage
 
 )code")


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to