eric-haibin-lin closed pull request #10557: [MXNET-322] Elemwise_mul(row_sparse, dense) = row_sparse on CPU
URL: https://github.com/apache/incubator-mxnet/pull/10557
This is a PR merged from a forked repository. As GitHub hides the original diff on merge, it is displayed below for the sake of provenance: As this is a foreign pull request (from a fork), the diff is supplied below (as it won't show otherwise due to GitHub magic): diff --git a/src/common/utils.h b/src/common/utils.h index 4f84a54e503..f0ef94097bb 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -432,6 +432,7 @@ inline void LogStorageFallback(const nnvm::NodeAttrs& attrs, "for execution. You're seeing this warning message because the operator above is unable " "to process the given ndarrays with specified storage types, context and parameter. " "Temporary dense ndarrays are generated in order to execute the operator. " + "This does not affect the correctness of the programme. " "You can set environment variable MXNET_STORAGE_FALLBACK_LOG_VERBOSE to " "0 to suppress this warning."; os << "\nStorage type fallback detected:\n" << op_str << warning; diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index d175a13632a..1ef7759b4da 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -1192,8 +1192,9 @@ void CopyFromTo(const NDArray& from, const NDArray& to, int priority) { << " to " << stype_string(to_stype) << " storage type on " << dev_type_string(b) << ".\nA temporary ndarray with " << stype_string(to_stype) << " storage type will be generated in order to perform the copy. " - << "You can set environment variable " - << "MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning."; + "This does not affect the correctness of the programme. 
" + "You can set environment variable " + "MXNET_STORAGE_FALLBACK_LOG_VERBOSE to 0 to suppress this warning."; LogOnce(os.str()); } diff --git a/src/operator/tensor/elemwise_binary_op.h b/src/operator/tensor/elemwise_binary_op.h index 9a151d38f81..88360791cee 100644 --- a/src/operator/tensor/elemwise_binary_op.h +++ b/src/operator/tensor/elemwise_binary_op.h @@ -250,7 +250,9 @@ class ElemwiseBinaryOp : public OpBase { std::vector<int> *out_attrs); /*! - * \brief Allow one of the inputs to be dense and still produce a sparse output + * \brief Allow one of the binary inputs to be dense and still produce a sparse output. + * Typically used for sparse * dense = sparse. + * Note: for csr, it dispatches to fallback other than csr, csr -> csr * \param attrs Attributes * \param dev_mask Device mask * \param dispatch_mode Dispatch Mode @@ -258,12 +260,12 @@ class ElemwiseBinaryOp : public OpBase { * \param out_attrs Output storage attributes * \return true if handled */ - template<bool lhs_dense_ok = true, bool rhs_dense_ok = true> - static bool AllowLRDenseInputWithSparseOutputStorageType(const nnvm::NodeAttrs& attrs, - int dev_mask, - DispatchMode* dispatch_mode, - std::vector<int> *in_attrs, - std::vector<int> *out_attrs) { + static bool PreferSparseStorageType(const nnvm::NodeAttrs& attrs, + int dev_mask, + DispatchMode* dispatch_mode, + std::vector<int> *in_attrs, + std::vector<int> *out_attrs) { + using namespace common; CHECK_EQ(in_attrs->size(), 2U) << " in operator " << attrs.name; CHECK_EQ(out_attrs->size(), 1U) << " in operator " << attrs.name; const auto& lhs_stype = in_attrs->at(0); @@ -273,31 +275,28 @@ class ElemwiseBinaryOp : public OpBase { const bool invalid_ctx = dev_mask != mshadow::cpu::kDevMask; const auto dispatch_ex = invalid_ctx ? 
DispatchMode::kFComputeFallback : DispatchMode::kFComputeEx; - if (!dispatched && lhs_stype == kDefaultStorage && rhs_stype == kDefaultStorage) { + if (!dispatched && ContainsOnlyStorage(*in_attrs, kDefaultStorage)) { // dns, dns -> dns dispatched = storage_type_assign(&out_stype, kDefaultStorage, dispatch_mode, DispatchMode::kFCompute); } - if (!dispatched) { - if ((lhs_stype == kRowSparseStorage && rhs_stype == kRowSparseStorage) || - (rhs_dense_ok && lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage) || - (lhs_dense_ok && lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage)) { + if (!dispatched && ContainsOnlyStorage(*in_attrs, kRowSparseStorage)) { // rsp, rsp -> rsp - // rsp, dns -> rsp - // dns, rsp -> rsp dispatched = storage_type_assign(&out_stype, kRowSparseStorage, dispatch_mode, dispatch_ex); - } else if (lhs_stype == kCSRStorage && rhs_stype == kCSRStorage) { + } + if (!dispatched && ContainsOnlyStorage(*in_attrs, kCSRStorage)) { // csr, csr -> csr dispatched = storage_type_assign(&out_stype, kCSRStorage, dispatch_mode, dispatch_ex); - } else if ((lhs_stype == kCSRStorage && rhs_dense_ok) || - (rhs_stype == kCSRStorage && lhs_dense_ok)) { - // csr, dns -> csr - // dns, csr -> csr - dispatched = storage_type_assign(&out_stype, kCSRStorage, - dispatch_mode, DispatchMode::kFComputeFallback); - } + } + if (!dispatched && + ((lhs_stype == kRowSparseStorage && rhs_stype == kDefaultStorage) || + (lhs_stype == kDefaultStorage && rhs_stype == kRowSparseStorage))) { + // rsp, dns -> rsp + // dns, rsp -> rsp + dispatched = storage_type_assign(&out_stype, kRowSparseStorage, + dispatch_mode, dispatch_ex); } if (!dispatched) { dispatched = dispatch_fallback(out_attrs, dispatch_mode); @@ -372,14 +371,15 @@ class ElemwiseBinaryOp : public OpBase { const std::vector<NDArray> &inputs, const std::vector<OpReqType> &req, const std::vector<NDArray> &outputs) { + using namespace common; CHECK_EQ(inputs.size(), 2); CHECK_EQ(outputs.size(), 1); if 
(req[0] == kNullOp) return; const auto lhs_stype = inputs[0].storage_type(); const auto out_stype = outputs[0].storage_type(); mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); - if ((common::ContainsOnlyStorage(inputs, kRowSparseStorage)) - && (out_stype == kRowSparseStorage || out_stype == kDefaultStorage)) { + if ((ContainsOnlyStorage(inputs, kRowSparseStorage)) && + (out_stype == kRowSparseStorage || out_stype == kDefaultStorage)) { // rsp, rsp -> rsp // rsp, rsp -> dns const int rsp_input_idx = lhs_stype == kRowSparseStorage ? 0 : 1; @@ -389,7 +389,7 @@ class ElemwiseBinaryOp : public OpBase { s, attrs, ctx, inputs[0], inputs[1], req[0], outputs[0], false, false, false, false); }); }); - } else if (common::ContainsOnlyStorage(inputs, kCSRStorage) && out_stype == kCSRStorage) { + } else if (ContainsOnlyStorage(inputs, kCSRStorage) && out_stype == kCSRStorage) { // csr, csr -> csr MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(csr::kIdx), IType, { MSHADOW_IDX_TYPE_SWITCH(inputs[0].aux_type(csr::kIndPtr), CType, { @@ -579,6 +579,19 @@ class ElemwiseBinaryOp : public OpBase { [](const NodeAttrs& attrs) { \ return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};}) +/*! \brief Binary launch, with FComputeEx for csr and rsp available. + when inputs contain both sparse and dense, sparse output is preferred. */ +#define MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU_PS(__name$, __kernel$) \ + MXNET_OPERATOR_REGISTER_BINARY(__name$) \ + .set_attr<FInferStorageType>("FInferStorageType", \ + ElemwiseBinaryOp::PreferSparseStorageType) \ + .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, __kernel$>) \ + .set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeEx<cpu, __kernel$>) \ + .set_attr<FResourceRequest>("FResourceRequest", /* For Sparse CSR */ \ + [](const NodeAttrs& attrs) { \ + return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};}) + + /*! 
\brief Binary launch, dense result * FInferStorageType attr is not set using this macro. * By default DefaultStorageType is used. diff --git a/src/operator/tensor/elemwise_binary_op_basic.cc b/src/operator/tensor/elemwise_binary_op_basic.cc index d73edc72352..96107f87867 100644 --- a/src/operator/tensor/elemwise_binary_op_basic.cc +++ b/src/operator/tensor/elemwise_binary_op_basic.cc @@ -195,15 +195,14 @@ The storage type of ``elemwise_mul`` output depends on storage types of inputs - elemwise_mul(default, default) = default - elemwise_mul(row_sparse, row_sparse) = row_sparse - - elemwise_mul(default, row_sparse) = default - - elemwise_mul(row_sparse, default) = default + - elemwise_mul(default, row_sparse) = row_sparse + - elemwise_mul(row_sparse, default) = row_sparse - elemwise_mul(csr, csr) = csr - otherwise, ``elemwise_mul`` generates output with default storage )code") .set_attr<FInferStorageType>("FInferStorageType", - ElemwiseBinaryOp::AllowLRDenseInputWithSparseOutputStorageType< - false, false>) // 0 * nan or nan * 0 -> nan, so rsp * dns -> dns + ElemwiseBinaryOp::PreferSparseStorageType) .set_attr<FCompute>("FCompute<cpu>", ElemwiseBinaryOp::Compute<cpu, op::mshadow_op::mul>) .set_attr<FComputeEx>("FComputeEx<cpu>", ElemwiseBinaryOp::ComputeDnsLRValueEx<cpu, op::mshadow_op::mul, true, true>) diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 0c4e37af107..37710843544 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -29,6 +29,7 @@ #include <vector> #include <utility> #include <algorithm> +#include "./cast_storage-inl.h" #include "../mshadow_op.h" #include "../mxnet_op.h" #include "../elemwise_op_common.h" @@ -328,15 +329,17 @@ class UnaryOp : public OpBase { const std::vector<NDArray>& inputs, const std::vector<OpReqType>& req, const std::vector<NDArray>& outputs) { - using namespace mshadow; - using namespace mshadow::expr; CHECK_EQ(inputs.size(), 2); 
CHECK_EQ(outputs.size(), 1); const auto lhs_stype = inputs[0].storage_type(); const auto out_stype = outputs[0].storage_type(); - if (lhs_stype == out_stype && (lhs_stype == kRowSparseStorage || lhs_stype == kCSRStorage)) { + bool supported_stype = lhs_stype == kRowSparseStorage || lhs_stype == kCSRStorage; + if (supported_stype && lhs_stype == out_stype) { // csr, _ -> csr, or rsp, _ -> rsp OpBase::CopyNDArray(ctx.get_stream<xpu>(), &outputs[0], req[0], inputs[0]); + } else if (supported_stype && out_stype == kDefaultStorage) { + // csr/rsp, _ -> dns + CastStorageComputeImpl<xpu>(ctx, inputs[0], outputs[0]); } else { LogUnimplementedOp(attrs, ctx, inputs, req, outputs); } diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index acd8f7b23ff..e711148898f 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -37,8 +37,8 @@ static bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, std::vector<int> *out_attrs) { CHECK_EQ(in_attrs->size(), 2U); CHECK_EQ(out_attrs->size(), 1U); - auto& lhs_stype = in_attrs->at(0); const auto& rhs_stype = in_attrs->at(1); + auto& lhs_stype = in_attrs->at(0); auto& out_stype = out_attrs->at(0); bool dispatched = false; @@ -57,9 +57,10 @@ static bool IdentityAttrLikeRhsStorageType(const nnvm::NodeAttrs& attrs, dispatched = storage_type_assign(&out_stype, static_cast<NDArrayStorageType>(out_stype), dispatch_mode, DispatchMode::kFComputeEx); } - if (!dispatched && (rhs_stype == kRowSparseStorage || rhs_stype == kCSRStorage)) { - // rsp, _ -> rsp, or csr, _ -> csr - dispatched = storage_type_assign(&out_stype, static_cast<NDArrayStorageType>(rhs_stype), + if (!dispatched && (lhs_stype == kRowSparseStorage || lhs_stype == kCSRStorage) && + (out_stype == kDefaultStorage)) { + // rsp/csr, _ -> dns + dispatched = storage_type_assign(&out_stype, static_cast<NDArrayStorageType>(out_stype), dispatch_mode, 
DispatchMode::kFComputeEx); } if (!dispatched) { @@ -294,6 +295,7 @@ The storage type of ``make_loss`` output depends upon the input storage type: // identity output as first input, but attributes (shape and type) are constrained to be like rhs // storage type attribute is not constrained to be like rhs if it is already defined +// for internal use only NNVM_REGISTER_OP(_identity_with_attr_like_rhs) .set_num_inputs(2) .set_attr<nnvm::FListInputNames>("FListInputNames", diff --git a/src/operator/tensor/init_op.cc b/src/operator/tensor/init_op.cc index d5a328efc23..52cb9f20ea7 100644 --- a/src/operator/tensor/init_op.cc +++ b/src/operator/tensor/init_op.cc @@ -88,7 +88,7 @@ NNVM_REGISTER_OP(_arange) NNVM_REGISTER_OP(zeros_like) .add_alias("_sparse_zeros_like") -.describe(R"code(Return an array of zeros with the same shape and type +.describe(R"code(Return an array of zeros with the same shape, type and storage type as the input array. The storage type of ``zeros_like`` output depends on the storage type of the input diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py index 34794866546..ec8af832f06 100644 --- a/tests/python/unittest/test_sparse_operator.py +++ b/tests/python/unittest/test_sparse_operator.py @@ -322,17 +322,15 @@ def lt(l, r): def le(l, r): return check_all(l, r, lambda a, b: a <= b) - def least_sparse(lstype, rstype): - if lstype == 'default' and rstype == 'default': + def elemwise_mul_stype(lstype, rstype): + if lstype == rstype: + return lstype + elif lstype == 'default' and rstype == 'row_sparse': + return 'row_sparse' + elif lstype == 'row_sparse' and rstype == 'default': + return 'row_sparse' + else: return 'default' - elif rstype != 'default': - return rstype - return lstype - - def most_dense(lstype, rstype): - if lstype == rstype: - return lstype - return 'default' def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape, lhs_grad_stype=None, rhs_grad_stype=None, @@ -367,9 +365,9 @@ def 
check_elemwise_binary_ops(lhs_stype, rhs_stype, shape, lambda l, r: mx.sym.sparse.elemwise_mul(l, r), lambda l, r: l * r, lambda outg, l, r: (outg * r, outg * l), - least_sparse(lhs_stype, rhs_stype), - least_sparse(lhs_stype, rhs_stype), - expected_result_storage_type=most_dense(lhs_stype, rhs_stype), + elemwise_mul_stype(lhs_stype, rhs_stype), + elemwise_mul_stype(lhs_stype, rhs_stype), + expected_result_storage_type=elemwise_mul_stype(lhs_stype, rhs_stype), ograd_density=ograd_density, force_lr_overlap=force_lr_overlap, force_grad_overlap=force_grad_overlap, @@ -442,10 +440,10 @@ def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape, check_elemwise_binary_ops('default', 'default', rand_shape_2d()) # Try different densities + shape = rand_shape_2d() for lhs_density in [0.0, random.uniform(0, 1), 1.0]: for rhs_density in [0.0, random.uniform(0, 1), 1.0]: for ograd_density in [0.0, random.uniform(0, 1), 1.0]: - shape = rand_shape_2d() print("lhs_density={}, rhs_density={}, ograd_density={}, shape: {}".format( lhs_density, rhs_density, ograd_density, shape)) @@ -454,8 +452,6 @@ def check_elemwise_binary_ops(lhs_stype, rhs_stype, shape, for force_lr_overlap in [False, True]: for force_grad_overlap in [False, True]: - shape = rand_shape_2d() - print(" force_lr_overlap={}, force_grad_overlap={}, shape={}". format(force_lr_overlap, force_grad_overlap, shape)) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services