This is an automated email from the ASF dual-hosted git repository. haibin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 2cd09a0  Add support for cast storage on same stypes (#10400)
2cd09a0 is described below

commit 2cd09a0c27b6dc73fd50c1ac4eb51df6e493eb9a
Author: Anirudh Subramanian <anirudh2...@gmail.com>
AuthorDate: Wed Apr 4 20:37:40 2018 -0700

    Add support for cast storage on same stypes (#10400)

    * Add cast storage support for same stypes

    * Add imports

    * Fix cast

    * Fix doc for cast_storage

    * Fix
---
 src/operator/tensor/cast_storage-inl.h        | 51 +++++++++++++++++++++++++++
 src/operator/tensor/cast_storage.cc           |  2 ++
 tests/python/unittest/test_sparse_operator.py |  3 ++
 3 files changed, 56 insertions(+)

diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index 46de10a..f905bf8 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -30,6 +30,7 @@
 #include <algorithm>
 #include "../mxnet_op.h"
 #include "../operator_common.h"
+#include "../../src/operator/tensor/init_op.h"
 #ifdef __CUDACC__
 #include "./cast_storage-inl.cuh"
 #endif  // __CUDACC__
@@ -328,6 +329,50 @@ void CastStorageCsrDnsImpl(const OpContext& ctx,
   });
 }
 
+/*!
+ * \brief Casts a csr matrix to another csr.
+ */
+template <typename xpu>
+void CastStorageCsrCsrImpl(const OpContext& ctx, const NDArray& csr,
+                           NDArray* output) {
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!csr.storage_initialized()) {
+    FillZerosCsrImpl(s, *output);
+    return;
+  }
+  std::vector<TShape> aux_shapes({csr.aux_shape(csr::kIndPtr), csr.aux_shape(csr::kIdx)});
+  output->CheckAndAlloc(aux_shapes);
+  const TBlob& val = output->data();
+  const TBlob& indptr = output->aux_data(csr::kIndPtr);
+  const TBlob& idx = output->aux_data(csr::kIdx);
+  mxnet_op::copy(s, val, csr.data());
+  mxnet_op::copy(s, indptr, csr.aux_data(csr::kIndPtr));
+  mxnet_op::copy(s, idx, csr.aux_data(csr::kIdx));
+}
+
+/*!
+ * \brief Casts a rsp matrix to another rsp.
+ */
+template <typename xpu>
+void CastStorageRspRspImpl(const OpContext& ctx, const NDArray& rsp,
+                           NDArray* output) {
+  CHECK_EQ(rsp.storage_type(), output->storage_type())
+      << "Copying with different storage type";
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!rsp.storage_initialized()) {
+    FillZerosRspImpl(s, *output);
+    return;
+  }
+  auto aux_shape = rsp.aux_shape(rowsparse::kIdx);
+  output->CheckAndAlloc({aux_shape});
+  const TBlob& val = output->data();
+  const TBlob& idx = output->aux_data(rowsparse::kIdx);
+  const TBlob& from_val = rsp.data();
+  const TBlob& from_idx = rsp.aux_data(rowsparse::kIdx);
+  mxnet_op::copy(s, val, from_val);
+  mxnet_op::copy(s, idx, from_idx);
+}
+
 template<typename xpu>
 void CastStorageComputeImpl(const OpContext& ctx,
                             const NDArray& input,
@@ -346,6 +391,12 @@ void CastStorageComputeImpl(const OpContext& ctx,
   } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) {
     TBlob ret = output.data();
     CastStorageCsrDnsImpl<xpu>(ctx, input, &ret);
+  } else if (src_stype == kCSRStorage && dst_stype == kCSRStorage) {
+    NDArray ret = output;
+    CastStorageCsrCsrImpl<xpu>(ctx, input, &ret);
+  } else if (src_stype == kRowSparseStorage && dst_stype == kRowSparseStorage) {
+    NDArray ret = output;
+    CastStorageRspRspImpl<xpu>(ctx, input, &ret);
 #if MXNET_USE_MKLDNN == 1
   } else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) {
     CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type);
diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc
index 9f257b1..f77a50a 100644
--- a/src/operator/tensor/cast_storage.cc
+++ b/src/operator/tensor/cast_storage.cc
@@ -46,6 +46,8 @@ The storage type of ``cast_storage`` output depends on stype parameter:
 - cast_storage(row_sparse, 'default') = default
 - cast_storage(default, 'csr') = csr
 - cast_storage(default, 'row_sparse') = row_sparse
+- cast_storage(csr, 'csr') = csr
+- cast_storage(row_sparse, 'row_sparse') = row_sparse
 
 Example::
 
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 9417df3..5ad5215 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1177,10 +1177,13 @@ def test_cast_storage_ex():
         shape_3d = rand_shape_3d()
         check_cast_storage(shape_2d, d, 'csr', 'default')
         check_cast_storage(shape_2d, d, 'default', 'csr')
+        check_cast_storage(shape_2d, d, 'csr', 'csr')
         check_cast_storage(shape_2d, d, 'row_sparse', 'default')
         check_cast_storage(shape_2d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_2d, d, 'row_sparse', 'row_sparse')
         check_cast_storage(shape_3d, d, 'row_sparse', 'default')
         check_cast_storage(shape_3d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_3d, d, 'row_sparse', 'row_sparse')
         for i in range(4, 6):
             shape = rand_shape_nd(i, 5)
             check_cast_storage(shape, d, 'default', 'row_sparse')

-- 
To stop receiving notification emails like this one, please contact
hai...@apache.org.