This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2cd09a0  Add support for cast storage on same stypes (#10400)
2cd09a0 is described below

commit 2cd09a0c27b6dc73fd50c1ac4eb51df6e493eb9a
Author: Anirudh Subramanian <anirudh2...@gmail.com>
AuthorDate: Wed Apr 4 20:37:40 2018 -0700

    Add support for cast storage on same stypes (#10400)
    
    * Add cast storage support for same stypes
    
    * Add imports
    
    * Fix cast
    
    * Fix doc for cast_storage
    
    * Fix
---
 src/operator/tensor/cast_storage-inl.h        | 51 +++++++++++++++++++++++++++
 src/operator/tensor/cast_storage.cc           |  2 ++
 tests/python/unittest/test_sparse_operator.py |  3 ++
 3 files changed, 56 insertions(+)

diff --git a/src/operator/tensor/cast_storage-inl.h b/src/operator/tensor/cast_storage-inl.h
index 46de10a..f905bf8 100644
--- a/src/operator/tensor/cast_storage-inl.h
+++ b/src/operator/tensor/cast_storage-inl.h
@@ -30,6 +30,7 @@
 #include <algorithm>
 #include "../mxnet_op.h"
 #include "../operator_common.h"
+#include "../../src/operator/tensor/init_op.h"
 #ifdef __CUDACC__
 #include "./cast_storage-inl.cuh"
 #endif  // __CUDACC__
@@ -328,6 +329,50 @@ void CastStorageCsrDnsImpl(const OpContext& ctx,
   });
 }
 
+/*!
+ * \brief Casts a CSR matrix to another CSR matrix (copies values, indptr, and indices).
+ */
+template <typename xpu>
+void CastStorageCsrCsrImpl(const OpContext& ctx, const NDArray& csr,
+                           NDArray* output) {
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!csr.storage_initialized()) {
+    FillZerosCsrImpl(s, *output);
+    return;
+  }
+  std::vector<TShape> aux_shapes({csr.aux_shape(csr::kIndPtr), csr.aux_shape(csr::kIdx)});
+  output->CheckAndAlloc(aux_shapes);
+  const TBlob& val = output->data();
+  const TBlob& indptr = output->aux_data(csr::kIndPtr);
+  const TBlob& idx = output->aux_data(csr::kIdx);
+  mxnet_op::copy(s, val, csr.data());
+  mxnet_op::copy(s, indptr, csr.aux_data(csr::kIndPtr));
+  mxnet_op::copy(s, idx, csr.aux_data(csr::kIdx));
+}
+
+/*!
+ * \brief Casts a row-sparse (RSP) matrix to another row-sparse matrix (copies values and row indices).
+ */
+template <typename xpu>
+void CastStorageRspRspImpl(const OpContext& ctx, const NDArray& rsp,
+                           NDArray* output) {
+  CHECK_EQ(rsp.storage_type(), output->storage_type())
+      << "Copying with different storage type";
+  mshadow::Stream<xpu>* s = ctx.get_stream<xpu>();
+  if (!rsp.storage_initialized()) {
+    FillZerosRspImpl(s, *output);
+    return;
+  }
+  auto aux_shape = rsp.aux_shape(rowsparse::kIdx);
+  output->CheckAndAlloc({aux_shape});
+  const TBlob& val = output->data();
+  const TBlob& idx = output->aux_data(rowsparse::kIdx);
+  const TBlob& from_val = rsp.data();
+  const TBlob& from_idx = rsp.aux_data(rowsparse::kIdx);
+  mxnet_op::copy(s, val, from_val);
+  mxnet_op::copy(s, idx, from_idx);
+}
+
 template<typename xpu>
 void CastStorageComputeImpl(const OpContext& ctx,
                             const NDArray& input,
@@ -346,6 +391,12 @@ void CastStorageComputeImpl(const OpContext& ctx,
   } else if (src_stype == kCSRStorage && dst_stype == kDefaultStorage) {
     TBlob ret = output.data();
     CastStorageCsrDnsImpl<xpu>(ctx, input, &ret);
+  } else if (src_stype == kCSRStorage && dst_stype == kCSRStorage) {
+    NDArray ret = output;
+    CastStorageCsrCsrImpl<xpu>(ctx, input, &ret);
+  } else if (src_stype == kRowSparseStorage && dst_stype == kRowSparseStorage) {
+    NDArray ret = output;
+    CastStorageRspRspImpl<xpu>(ctx, input, &ret);
 #if MXNET_USE_MKLDNN == 1
   } else if (src_stype == kDefaultStorage && dst_stype == kDefaultStorage) {
     CHECK_EQ(output.ctx().dev_type, input.ctx().dev_type);
diff --git a/src/operator/tensor/cast_storage.cc b/src/operator/tensor/cast_storage.cc
index 9f257b1..f77a50a 100644
--- a/src/operator/tensor/cast_storage.cc
+++ b/src/operator/tensor/cast_storage.cc
@@ -46,6 +46,8 @@ The storage type of ``cast_storage`` output depends on stype parameter:
 - cast_storage(row_sparse, 'default') = default
 - cast_storage(default, 'csr') = csr
 - cast_storage(default, 'row_sparse') = row_sparse
+- cast_storage(csr, 'csr') = csr
+- cast_storage(row_sparse, 'row_sparse') = row_sparse
 
 Example::
 
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 9417df3..5ad5215 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -1177,10 +1177,13 @@ def test_cast_storage_ex():
         shape_3d = rand_shape_3d()
         check_cast_storage(shape_2d, d, 'csr', 'default')
         check_cast_storage(shape_2d, d, 'default', 'csr')
+        check_cast_storage(shape_2d, d, 'csr', 'csr')
         check_cast_storage(shape_2d, d, 'row_sparse', 'default')
         check_cast_storage(shape_2d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_2d, d, 'row_sparse', 'row_sparse')
         check_cast_storage(shape_3d, d, 'row_sparse', 'default')
         check_cast_storage(shape_3d, d, 'default', 'row_sparse')
+        check_cast_storage(shape_3d, d, 'row_sparse', 'row_sparse')
         for i in range(4, 6):
             shape = rand_shape_nd(i, 5)
             check_cast_storage(shape, d, 'default', 'row_sparse')

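For context, a minimal usage sketch of the same-stype casts this commit enables, written against the public Python API (`mx.nd.sparse.cast_storage` and `NDArray.tostype`) rather than the internal `check_cast_storage` test helper; it assumes an MXNet build that includes this commit::

    import mxnet as mx

    # A small dense matrix with zeros, so sparse storage is meaningful.
    dense = mx.nd.array([[0, 1], [2, 0], [0, 0]])

    # csr -> csr: previously unsupported, now performs a copy.
    csr = dense.tostype('csr')
    csr_copy = mx.nd.sparse.cast_storage(csr, stype='csr')
    assert (csr_copy.asnumpy() == dense.asnumpy()).all()

    # row_sparse -> row_sparse: likewise copies values and row indices.
    rsp = dense.tostype('row_sparse')
    rsp_copy = mx.nd.sparse.cast_storage(rsp, stype='row_sparse')
    assert (rsp_copy.asnumpy() == dense.asnumpy()).all()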