This is an automated email from the ASF dual-hosted git repository.

wkcn pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new f01dc80  Adding sparse support to MXTensor for custom operators 
(#17569)
f01dc80 is described below

commit f01dc80f030d2d1912c8e134c95f373e9f1f8e7b
Author: guanxinq <58794120+guanx...@users.noreply.github.com>
AuthorDate: Sun Mar 22 03:50:55 2020 -0700

    Adding sparse support to MXTensor for custom operators (#17569)
    
    * Added enum for sparse storage
    
    * Add structure for Dense and Sparse
    
    * redesign the data structure for MXSparse
    
    * pull out aux data from sparse NDArray
    
    * Added more sparse arguments to API interface
    
    * Passed sparse from c_api to lib_api.h and set in MXTensor
    
    * Fix indent
    
    * fix segfault
    
    * Fix NDArray to MXTensor errors
    
    * Add a sample of sparse(CSR) transpose
    
    * Make CSR transpose temporarily work by hardcoding
    
    * Fixed sparse output size(Refined)
    
    * Add tests for symbolic and stateful ops
    
    * Added a sample for row sparse transpose
    
    * Added real row sparse transpose
    
    * Fix output size issue by adding lambda for CheckAndAlloc()
    
    * Fix mixed storage formats error
    
    * Added infer storage type function
    
    * resolve comments
    
    * Set inferSType as optional function
    
    * Resolve comments
    
    * Add error messages
    
    * Resolve comments
    
    * verify transpose ops results
    
    * fix sanity check
    
    * update MX_LIBRARY_VERSION to 5
---
 example/extensions/lib_custom_op/Makefile          |  10 +-
 .../extensions/lib_custom_op/test_transposecsr.py  |  78 ++++++
 .../lib_custom_op/test_transposerowsp.py           |  73 ++++++
 .../extensions/lib_custom_op/transposecsr_lib.cc   | 197 ++++++++++++++
 .../extensions/lib_custom_op/transposerowsp_lib.cc | 199 ++++++++++++++
 example/extensions/lib_subgraph/subgraph_lib.cc    |   4 +-
 include/mxnet/lib_api.h                            | 286 ++++++++++++++++++---
 src/c_api/c_api.cc                                 | 119 ++++++++-
 8 files changed, 919 insertions(+), 47 deletions(-)

diff --git a/example/extensions/lib_custom_op/Makefile 
b/example/extensions/lib_custom_op/Makefile
index edd753b..feded29 100644
--- a/example/extensions/lib_custom_op/Makefile
+++ b/example/extensions/lib_custom_op/Makefile
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-all: gemm_lib relu_lib
+all: gemm_lib relu_lib transposecsr_lib transposerowsp_lib
 
 gemm_lib:
        g++ -shared -fPIC -std=c++11 gemm_lib.cc -o libgemm_lib.so -I 
../../../include/mxnet
@@ -23,5 +23,11 @@ gemm_lib:
 relu_lib:
        nvcc -shared -std=c++11 -Xcompiler -fPIC relu_lib.cu -o librelu_lib.so 
-I ../../../include/mxnet
 
+# Build the CSR transpose example (transposecsr_lib.cc) as a shared library.
+transposecsr_lib:
+       g++ -shared -fPIC -std=c++11 transposecsr_lib.cc -o 
libtransposecsr_lib.so -I ../../../include/mxnet
+
+# Build the row sparse transpose example (transposerowsp_lib.cc) as a shared library.
+transposerowsp_lib:
+       g++ -shared -fPIC -std=c++11 transposerowsp_lib.cc -o 
libtransposerowsp_lib.so -I ../../../include/mxnet
+
 clean:
-       rm -rf libgemm_lib.so librelu_lib.so
+       rm -rf libgemm_lib.so librelu_lib.so libtransposecsr_lib.so 
libtransposerowsp_lib.so
diff --git a/example/extensions/lib_custom_op/test_transposecsr.py 
b/example/extensions/lib_custom_op/test_transposecsr.py
new file mode 100644
index 0000000..37d066a
--- /dev/null
+++ b/example/extensions/lib_custom_op/test_transposecsr.py
@@ -0,0 +1,78 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=arguments-differ
+
+# This test checks dynamic loading of custom library into MXNet
+# and checks end to end compute of a simple 2D CSR transpose custom op
+
+import mxnet as mx
+import os
+
+#load library
+if (os.name=='posix'):
+    path = os.path.abspath('libtransposecsr_lib.so')
+    mx.library.load(path)
+elif (os.name=='nt'):
+    path = os.path.abspath('libtransposecsr_lib.dll')
+    mx.library.load(path)
+
+a = mx.nd.array([[1,3,0,2,1],[0,1,0,0,0],[0,2,4,5,3]])
+a = a.tostype('csr')
+print("--------Input CSR Array---------")
+print("data:", a.data.asnumpy())
+print("indices:", a.indices.asnumpy())
+print("indptr:", a.indptr.asnumpy())
+
+print("--------Start NDArray Compute---------")
+b = mx.nd.my_transposecsr(a)
+print("Compute Results:")
+print("data:", b.data.asnumpy())
+print("indices:", b.indices.asnumpy())
+print("indptr:", b.indptr.asnumpy())
+
+print("Stateful Compute Result:")
+c = mx.nd.my_state_transposecsr(a, test_kw=100)
+print("data:", c.data.asnumpy())
+print("indices:", c.indices.asnumpy())
+print("indptr:", c.indptr.asnumpy())
+
+print("--------start symbolic compute--------")
+d = mx.sym.Variable('d')
+e = mx.sym.my_transposecsr(d)
+f = mx.sym.my_state_transposecsr(d, test_kw=200)
+
+exe = e.bind(ctx=mx.cpu(),args={'d':a})
+exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
+out = exe.forward()
+print("Compute Results:")
+print("data:", out[0].data.asnumpy())
+print("indices:", out[0].indices.asnumpy())
+print("indptr:", out[0].indptr.asnumpy())
+
+out2 = exe2.forward()
+out2 = exe2.forward()
+print("Stateful Compute Result:")
+print("data:", out2[0].data.asnumpy())
+print("indices:", out2[0].indices.asnumpy())
+print("indptr:", out2[0].indptr.asnumpy())
+
+print("--------Baseline(dense)--------")
+print(mx.nd.transpose(a.tostype('default')))
diff --git a/example/extensions/lib_custom_op/test_transposerowsp.py 
b/example/extensions/lib_custom_op/test_transposerowsp.py
new file mode 100644
index 0000000..cea62ec
--- /dev/null
+++ b/example/extensions/lib_custom_op/test_transposerowsp.py
@@ -0,0 +1,73 @@
+#!/usr/bin/env python3
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=arguments-differ
+
+# This test checks dynamic loading of custom library into MXNet
+# and checks end to end compute of a simple 2D row sparse transpose custom op
+
+import mxnet as mx
+import os
+
+#load library
+if (os.name=='posix'):
+    path = os.path.abspath('libtransposerowsp_lib.so')
+    mx.library.load(path)
+elif (os.name=='nt'):
+    path = os.path.abspath('libtransposerowsp_lib.dll')
+    mx.library.load(path)
+
+a = mx.nd.array([[1,2,3],[0,0,0],[4,0,5],[0,0,0],[0,0,0]])
+a = a.tostype('row_sparse')
+print("--------Input CSR Array---------")
+print("data:", a.data.asnumpy())
+print("indices:", a.indices.asnumpy())
+
+print("--------Start NDArray Compute---------")
+b = mx.nd.my_transposerowsp(a)
+print("Compute Results:")
+print("data:", b.data.asnumpy())
+print("indices:", b.indices.asnumpy())
+
+print("Stateful Compute Result:")
+c = mx.nd.my_state_transposerowsp(a, test_kw=100)
+print("data:", c.data.asnumpy())
+print("indices:", c.indices.asnumpy())
+
+print("--------start symbolic compute--------")
+d = mx.sym.Variable('d')
+e = mx.sym.my_transposerowsp(d)
+f = mx.sym.my_state_transposerowsp(d, test_kw=200)
+
+exe = e.bind(ctx=mx.cpu(),args={'d':a})
+exe2 = f.bind(ctx=mx.cpu(),args={'d':a})
+out = exe.forward()
+print("Compute Results:")
+print("data:", out[0].data.asnumpy())
+print("indices:", out[0].indices.asnumpy())
+
+out2 = exe2.forward()
+out2 = exe2.forward()
+print("Stateful Compute Result:")
+print("data:", out2[0].data.asnumpy())
+print("indices:", out2[0].indices.asnumpy())
+
+print("--------Baseline(dense)--------")
+print(mx.nd.transpose(a.tostype('default')))
diff --git a/example/extensions/lib_custom_op/transposecsr_lib.cc 
b/example/extensions/lib_custom_op/transposecsr_lib.cc
new file mode 100644
index 0000000..0daeb3e
--- /dev/null
+++ b/example/extensions/lib_custom_op/transposecsr_lib.cc
@@ -0,0 +1,197 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file transposecsr_lib.cc
+ * \brief Sample 2D CSR transpose custom operator.
+ */
+
+#include <iostream>
+#include "lib_api.h"
+
+/*!
+ * \brief Transpose a 2D CSR tensor.
+ *
+ * Counting-sort style transpose:
+ *  - count the non-zeros in each input column;
+ *  - prefix-sum those counts into the output row pointer;
+ *  - scatter every input element into its slot in the output.
+ *
+ * \param src input tensor; only processed when src.stype == kCSRStorage.
+ *            Its values are read as float, matching inferType's float32 check.
+ * \param dst output tensor; its MXSparse buffers are allocated here
+ * \param res resource used to allocate the output sparse buffers.
+ *            alloc_sparse takes int lengths, so nnz and w + 1 must fit in int.
+ */
+void transpose(MXTensor src, MXTensor dst, OpResource res) {
+  MXSparse* A = src.data<MXSparse>();
+  MXSparse* B = dst.data<MXSparse>();
+  std::vector<int64_t> shape = src.shape;
+  int64_t h = shape[0];
+  int64_t w = shape[1];
+  if (src.stype == kCSRStorage) {
+    float *Aval = static_cast<float*>(A->data);
+    // Two extra slots: counts are stored at offset +2 so that, after the
+    // prefix sum, rowPtr[c + 1] is the first output slot of column c and can
+    // be used as that column's insertion cursor in the scatter loop.
+    std::vector<int64_t> rowPtr(w + 2, 0);
+    // Count the non-zeros in each column.
+    for (int64_t i = 0; i < A->data_len; i++) {
+      rowPtr[A->indices[i] + 2]++;
+    }
+    // Prefix sum. After this loop rowPtr[1:w+2) holds the transposed row
+    // pointer shifted right by one position.
+    for (size_t i = 2; i < rowPtr.size(); i++) {
+      rowPtr[i] += rowPtr[i - 1];
+    }
+
+    // Alloc memory for sparse data, where 0 is the index of B in the output
+    // vector.
+    res.alloc_sparse(B, 0, static_cast<int>(A->data_len),
+                     static_cast<int>(w + 1));
+    float *Bval = static_cast<float*>(B->data);
+    for (int64_t i = 0; i < h; i++) {
+      for (int64_t j = A->indptr[i]; j < A->indptr[i + 1]; j++) {
+        // Post-increment the cursor of column A->indices[j]. Once every
+        // element is placed, rowPtr[0:w+1) is exactly the transposed row
+        // pointer.
+        int64_t index = rowPtr[A->indices[j] + 1]++;
+        Bval[index] = Aval[j];
+        B->indices[index] = i;
+      }
+    }
+    memcpy(B->indptr, rowPtr.data(), sizeof(int64_t) * (w + 1));
+  }
+}
+
+/*!
+ * \brief Stateless forward: transpose inputs[0] into outputs[0].
+ * \return MX_FAIL when the input and output differ in dtype or storage type,
+ *         otherwise MX_SUCCESS after running transpose().
+ */
+MXReturnValue forward(std::map<std::string, std::string> attrs,
+                      std::vector<MXTensor> inputs,
+                      std::vector<MXTensor> outputs,
+                      OpResource res) {
+  // The data types and storage types of inputs and outputs should be the same.
+  if (inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) {
+    // Leading space keeps the first "Found" from running into the sentence.
+    std::cout << "Error! Expected all inputs and outputs to be the same type."
+              << " Found input storage type:" << inputs[0].stype
+              << " Found output storage type:" << outputs[0].stype
+              << " Found input data type:" << inputs[0].dtype
+              << " Found output data type:" << outputs[0].dtype << std::endl;
+    return MX_FAIL;
+  }
+
+  transpose(inputs[0], outputs[0], res);
+  return MX_SUCCESS;
+}
+
+// Backward is a no-op in this example: the gradient outputs are left untouched
+// and success is reported unconditionally.
+MXReturnValue backward(std::map<std::string, std::string> attrs,
+                       std::vector<MXTensor> inputs,
+                       std::vector<MXTensor> outputs,
+                       OpResource res) {
+  const MXReturnValue status = MX_SUCCESS;
+  return status;
+}
+
+// Declares the operator arity: exactly one input and one output.
+MXReturnValue parseAttrs(std::map<std::string, std::string> attrs,
+                         int* num_in, int* num_out) {
+  const int kArity = 1;
+  *num_in = kArity;
+  *num_out = kArity;
+  return MX_SUCCESS;
+}
+
+// Type inference: requires exactly one float32 input and gives the output the
+// same dtype.
+MXReturnValue inferType(std::map<std::string, std::string> attrs,
+                        std::vector<int> &intypes,
+                        std::vector<int> &outtypes) {
+  const bool single_input = intypes.size() == 1;
+  if (!single_input) {
+    std::cout << "Expected 1 inputs to inferType" << std::endl;
+    return MX_FAIL;
+  }
+  const int in_dtype = intypes[0];
+  if (in_dtype != kFloat32) {
+    std::cout << "Expected input to have float32 type" << std::endl;
+    return MX_FAIL;
+  }
+  outtypes[0] = in_dtype;
+  return MX_SUCCESS;
+}
+
+// Storage type inference: requires exactly one CSR input and propagates the
+// CSR storage type to the output.
+MXReturnValue inferSType(std::map<std::string, std::string> attrs,
+                         std::vector<int> &instypes,
+                         std::vector<int> &outstypes) {
+  // validate inputs before indexing, as inferType/inferShape do
+  if (instypes.size() != 1) {
+    std::cout << "Expected 1 inputs to inferSType" << std::endl;
+    return MX_FAIL;
+  }
+  if (instypes[0] != kCSRStorage) {
+    std::cout << "Expected storage type is kCSRStorage" << std::endl;
+    return MX_FAIL;
+  }
+  outstypes[0] = instypes[0];
+  return MX_SUCCESS;
+}
+
+// Shape inference: the output shape is the input shape with its two axes
+// swapped. Only 2D inputs are supported.
+MXReturnValue inferShape(std::map<std::string, std::string> attrs,
+                         std::vector<std::vector<unsigned int>> &inshapes,
+                         std::vector<std::vector<unsigned int>> &outshapes) {
+  // validate inputs
+  if (inshapes.size() != 1) {
+    std::cout << "Expected 1 inputs to inferShape" << std::endl;
+    return MX_FAIL;
+  }
+  // Both axes are read below (and by transpose), so a non-2D shape would be
+  // indexed out of bounds.
+  if (inshapes[0].size() != 2) {
+    std::cout << "Expected 2D input to inferShape" << std::endl;
+    return MX_FAIL;
+  }
+
+  // Assign rather than append so the result is exactly the two swapped dims.
+  outshapes[0] = {inshapes[0][1], inshapes[0][0]};
+  return MX_SUCCESS;
+}
+
+// Register the stateless CSR transpose. Forward/backward are registered for
+// "cpu"; the parse/type/storage-type/shape callbacks defined above validate
+// the input and compute the output metadata.
+REGISTER_OP(my_transposecsr)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape);
+
+/* ------------------------------------------------------------------------- */
+
+/*!
+ * \brief Stateful variant of the CSR transpose operator.
+ *
+ * Holds a forward-call counter supplied at construction. Each Forward
+ * increments and prints it, then delegates to the stateless
+ * forward/backward functions with an empty attribute map.
+ */
+class MyStatefulTransposeCSR : public CustomStatefulOp {
+ public:
+  explicit MyStatefulTransposeCSR(int count) : count(count) {}
+
+  MXReturnValue Forward(std::vector<MXTensor> inputs,
+                        std::vector<MXTensor> outputs,
+                        OpResource op_res) {
+    count += 1;
+    std::cout << "Info: keyword + number of forward: " << count << std::endl;
+    const std::map<std::string, std::string> no_attrs;
+    return forward(no_attrs, inputs, outputs, op_res);
+  }
+
+  MXReturnValue Backward(std::vector<MXTensor> inputs,
+                         std::vector<MXTensor> outputs,
+                         OpResource op_res) {
+    const std::map<std::string, std::string> no_attrs;
+    return backward(no_attrs, inputs, outputs, op_res);
+  }
+
+ private:
+  // Number of Forward calls so far, offset by the initial constructor value.
+  int count;
+};
+
+// Factory for the stateful operator. Seeds the forward counter from the
+// optional "test_kw" attribute (0 when absent); std::stoi throws if the value
+// is not an integer.
+MXReturnValue createOpState(std::map<std::string, std::string> attrs,
+                            CustomStatefulOp** op_inst) {
+  int count = 0;
+  auto kw = attrs.find("test_kw");
+  if (kw != attrs.end()) {
+    count = std::stoi(kw->second);
+  }
+  *op_inst = new MyStatefulTransposeCSR(count);
+  std::cout << "Info: stateful operator created" << std::endl;
+  return MX_SUCCESS;
+}
+
+// Register the stateful CSR transpose. It reuses the same inference callbacks,
+// and createOpState builds a MyStatefulTransposeCSR instance for "cpu".
+REGISTER_OP(my_state_transposecsr)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape)
+.setCreateOpState(createOpState, "cpu");
+
+// Library entry point: accepts MXNet versions >= 10400 (presumably 1.4.0) and
+// reports the decision on stdout.
+MXReturnValue initialize(int version) {
+  const bool supported = version >= 10400;
+  std::cout << "MXNet version " << version
+            << (supported ? " supported" : " not supported") << std::endl;
+  return supported ? MX_SUCCESS : MX_FAIL;
+}
diff --git a/example/extensions/lib_custom_op/transposerowsp_lib.cc 
b/example/extensions/lib_custom_op/transposerowsp_lib.cc
new file mode 100644
index 0000000..883d816
--- /dev/null
+++ b/example/extensions/lib_custom_op/transposerowsp_lib.cc
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file transposerowsp_lib.cc
+ * \brief Sample 2D row sparse transpose custom operator.
+ */
+
+#include <iostream>
+#include "lib_api.h"
+
+/*!
+ * \brief Transpose a 2D row sparse tensor.
+ *
+ * Column c of the input becomes row c of the output. Only output rows that
+ * receive at least one non-zero value are stored. std::map keeps them sorted
+ * by row index.
+ *
+ * \param src input tensor; only processed when src.stype == kRowSparseStorage.
+ *            Its values are read as float, matching inferType's float32 check.
+ * \param dst output tensor; its MXSparse buffers are allocated here
+ * \param res resource used to allocate the output sparse buffers
+ */
+void transpose(MXTensor src, MXTensor dst, OpResource res) {
+  MXSparse* A = src.data<MXSparse>();
+  MXSparse* B = dst.data<MXSparse>();
+
+  std::vector<int64_t> shape = src.shape;
+  int64_t h = shape[0];
+  int64_t w = shape[1];
+  if (src.stype == kRowSparseStorage) {
+    // Keys are row indices of the transposed tensor (= input columns).
+    // Values are the dense output rows, each of length h.
+    std::map<int64_t, std::vector<float>> mp;
+    float *Aval = static_cast<float*>(A->data);
+    // A->data holds the stored input rows back to back, w values per row.
+    for (int64_t i = 0; i < A->data_len; i++) {
+      if (Aval[i] == 0) continue;
+      int64_t col = i % w;
+      int64_t row = A->indices[i / w];
+      // Create the zero-filled output row on first use, then set the value.
+      std::vector<float>& out_row = mp[col];
+      if (out_row.empty()) out_row.assign(h, 0.0f);
+      out_row[row] = Aval[i];
+    }
+
+    // Alloc memory for output tensors: one stored row per map entry.
+    res.alloc_sparse(B, 0, static_cast<int>(mp.size()));
+    float *Bval = static_cast<float*>(B->data);
+    int64_t didx = 0, iidx = 0;
+    // Iterate by const reference; copying an entry would copy a length-h vector.
+    for (const auto& entry : mp) {
+      B->indices[iidx++] = entry.first;
+      for (float v : entry.second) {
+        Bval[didx++] = v;
+      }
+    }
+  }
+}
+
+/*!
+ * \brief Stateless forward: transpose inputs[0] into outputs[0].
+ * \return MX_FAIL when the input and output differ in dtype or storage type,
+ *         otherwise MX_SUCCESS after running transpose().
+ */
+MXReturnValue forward(std::map<std::string, std::string> attrs,
+                      std::vector<MXTensor> inputs,
+                      std::vector<MXTensor> outputs,
+                      OpResource res) {
+  // The data types and storage types of inputs and outputs should be the same.
+  if (inputs[0].dtype != outputs[0].dtype || inputs[0].stype != outputs[0].stype) {
+    // Leading space keeps the first "Found" from running into the sentence.
+    std::cout << "Error! Expected all inputs and outputs to be the same type."
+              << " Found input storage type:" << inputs[0].stype
+              << " Found output storage type:" << outputs[0].stype
+              << " Found input data type:" << inputs[0].dtype
+              << " Found output data type:" << outputs[0].dtype << std::endl;
+    return MX_FAIL;
+  }
+  transpose(inputs[0], outputs[0], res);
+  return MX_SUCCESS;
+}
+
+// Backward is a no-op in this example: the gradient outputs are left untouched
+// and success is reported unconditionally.
+MXReturnValue backward(std::map<std::string, std::string> attrs,
+                       std::vector<MXTensor> inputs,
+                       std::vector<MXTensor> outputs,
+                       OpResource res) {
+  const MXReturnValue status = MX_SUCCESS;
+  return status;
+}
+
+// Declares the operator arity: exactly one input and one output.
+MXReturnValue parseAttrs(std::map<std::string, std::string> attrs,
+                         int* num_in, int* num_out) {
+  const int kArity = 1;
+  *num_in = kArity;
+  *num_out = kArity;
+  return MX_SUCCESS;
+}
+
+// Type inference: requires exactly one float32 input and gives the output the
+// same dtype.
+MXReturnValue inferType(std::map<std::string, std::string> attrs,
+                        std::vector<int> &intypes,
+                        std::vector<int> &outtypes) {
+  const bool single_input = intypes.size() == 1;
+  if (!single_input) {
+    std::cout << "Expected 1 inputs to inferType" << std::endl;
+    return MX_FAIL;
+  }
+  const int in_dtype = intypes[0];
+  if (in_dtype != kFloat32) {
+    std::cout << "Expected input to have float32 type" << std::endl;
+    return MX_FAIL;
+  }
+  outtypes[0] = in_dtype;
+  return MX_SUCCESS;
+}
+
+// Storage type inference: requires exactly one row sparse input and
+// propagates the row sparse storage type to the output.
+MXReturnValue inferSType(std::map<std::string, std::string> attrs,
+                         std::vector<int> &instypes,
+                         std::vector<int> &outstypes) {
+  // validate inputs before indexing, as inferType/inferShape do
+  if (instypes.size() != 1) {
+    std::cout << "Expected 1 inputs to inferSType" << std::endl;
+    return MX_FAIL;
+  }
+  if (instypes[0] != kRowSparseStorage) {
+    std::cout << "Expected storage type is kRowSparseStorage" << std::endl;
+    return MX_FAIL;
+  }
+  outstypes[0] = instypes[0];
+  return MX_SUCCESS;
+}
+
+// Shape inference: the output shape is the input shape with its two axes
+// swapped. Only 2D inputs are supported.
+MXReturnValue inferShape(std::map<std::string, std::string> attrs,
+                         std::vector<std::vector<unsigned int>> &inshapes,
+                         std::vector<std::vector<unsigned int>> &outshapes) {
+  // validate inputs
+  if (inshapes.size() != 1) {
+    std::cout << "Expected 1 inputs to inferShape" << std::endl;
+    return MX_FAIL;
+  }
+  // Both axes are read below (and by transpose), so a non-2D shape would be
+  // indexed out of bounds.
+  if (inshapes[0].size() != 2) {
+    std::cout << "Expected 2D input to inferShape" << std::endl;
+    return MX_FAIL;
+  }
+
+  // Assign rather than append so the result is exactly the two swapped dims.
+  outshapes[0] = {inshapes[0][1], inshapes[0][0]};
+  return MX_SUCCESS;
+}
+
+// Register the stateless row sparse transpose. Forward/backward are
+// registered for "cpu"; the parse/type/storage-type/shape callbacks defined
+// above validate the input and compute the output metadata.
+REGISTER_OP(my_transposerowsp)
+.setForward(forward, "cpu")
+.setBackward(backward, "cpu")
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape);
+
+/* ------------------------------------------------------------------------- */
+
+/*!
+ * \brief Stateful variant of the row sparse transpose operator.
+ *
+ * Holds a forward-call counter supplied at construction. Each Forward
+ * increments and prints it, then delegates to the stateless
+ * forward/backward functions with an empty attribute map.
+ */
+class MyStatefulTransposeRowSP : public CustomStatefulOp {
+ public:
+  explicit MyStatefulTransposeRowSP(int count) : count(count) {}
+
+  MXReturnValue Forward(std::vector<MXTensor> inputs,
+                        std::vector<MXTensor> outputs,
+                        OpResource op_res) {
+    count += 1;
+    std::cout << "Info: keyword + number of forward: " << count << std::endl;
+    const std::map<std::string, std::string> no_attrs;
+    return forward(no_attrs, inputs, outputs, op_res);
+  }
+
+  MXReturnValue Backward(std::vector<MXTensor> inputs,
+                         std::vector<MXTensor> outputs,
+                         OpResource op_res) {
+    const std::map<std::string, std::string> no_attrs;
+    return backward(no_attrs, inputs, outputs, op_res);
+  }
+
+ private:
+  // Number of Forward calls so far, offset by the initial constructor value.
+  int count;
+};
+
+// Factory for the stateful operator. Seeds the forward counter from the
+// optional "test_kw" attribute (0 when absent); std::stoi throws if the value
+// is not an integer.
+MXReturnValue createOpState(std::map<std::string, std::string> attrs,
+                            CustomStatefulOp** op_inst) {
+  int count = 0;
+  auto kw = attrs.find("test_kw");
+  if (kw != attrs.end()) {
+    count = std::stoi(kw->second);
+  }
+  *op_inst = new MyStatefulTransposeRowSP(count);
+  std::cout << "Info: stateful operator created" << std::endl;
+  return MX_SUCCESS;
+}
+
+// Register the stateful row sparse transpose. It reuses the same inference
+// callbacks, and createOpState builds a MyStatefulTransposeRowSP instance for
+// "cpu".
+REGISTER_OP(my_state_transposerowsp)
+.setParseAttrs(parseAttrs)
+.setInferType(inferType)
+.setInferSType(inferSType)
+.setInferShape(inferShape)
+.setCreateOpState(createOpState, "cpu");
+
+// Library entry point: accepts MXNet versions >= 10400 (presumably 1.4.0) and
+// reports the decision on stdout.
+MXReturnValue initialize(int version) {
+  const bool supported = version >= 10400;
+  std::cout << "MXNet version " << version
+            << (supported ? " supported" : " not supported") << std::endl;
+  return supported ? MX_SUCCESS : MX_FAIL;
+}
diff --git a/example/extensions/lib_subgraph/subgraph_lib.cc 
b/example/extensions/lib_subgraph/subgraph_lib.cc
index 8c24dd8..d821bdb 100644
--- a/example/extensions/lib_subgraph/subgraph_lib.cc
+++ b/example/extensions/lib_subgraph/subgraph_lib.cc
@@ -84,7 +84,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
       // get input tensor based on node ID inputs from data storage
       MXTensor &input = data[node_inputs.list[0].list[0].num];
       // create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, 
{"cpu", 0});
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, 
{"cpu", 0}, kDefaultStorage);
       // save allocated ptr to free later
       to_free.push_back(tmp.data_ptr);
       // execute log operator
@@ -95,7 +95,7 @@ MXReturnValue myExecutor(std::vector<MXTensor> inputs,
       // get input tensor based on node ID inputs from data storage
       MXTensor &input = data[node_inputs.list[0].list[0].num];
       // create temporary storage
-      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, 
{"cpu", 0});
+      MXTensor tmp(malloc(input.size()*4), input.shape, input.dtype, 0, 
{"cpu", 0}, kDefaultStorage);
       // save allocated ptr to free later
       to_free.push_back(tmp.data_ptr);
       // execute exp operator 
diff --git a/include/mxnet/lib_api.h b/include/mxnet/lib_api.h
index 9b32122..fd526ee 100644
--- a/include/mxnet/lib_api.h
+++ b/include/mxnet/lib_api.h
@@ -39,7 +39,7 @@
 #include <utility>
 #include <stdexcept>
 
-#define MX_LIBRARY_VERSION 4
+#define MX_LIBRARY_VERSION 5
 
 /*!
  * \brief For loading multiple custom op libraries in Linux, exporting same 
symbol multiple
@@ -214,6 +214,18 @@ enum MXDType {
   kUNSET = 100,
 };
 
+/*!
+ * \brief MXTensor storage type.
+ * NOTE(review): these values are exchanged as raw ints across the C ABI
+ * (e.g. the int* arrays of opCallInferSType_t). They presumably mirror MXNet's
+ * internal NDArray storage type codes, so confirm before renumbering.
+ */
+enum MXStorageType {
+  // dense
+  kDefaultStorage = 0,
+  // row sparse
+  kRowSparseStorage = 1,
+  // csr
+  kCSRStorage = 2,
+};
+
 /*!
  * \brief Context info passing from MXNet OpContext
  * dev_type is string repr of supported context, currently only "cpu" and "gpu"
@@ -229,25 +241,64 @@ enum MXReturnValue {
   MX_SUCCESS = 1,
 };
 
+/*!
+ * \brief View of the buffers backing a sparse (CSR or row sparse) tensor.
+ *
+ * The data is read/written from the NDArray through these pointers. Every
+ * field has an initializer so that a default-constructed MXSparse never
+ * exposes indeterminate lengths or pointers.
+ */
+struct MXSparse {
+  // Pointer to data.
+  void *data{nullptr};
+  // length of (non-zero) data.
+  int64_t data_len{0};
+
+  // To store aux data for sparse.
+  // For CSR, indices stores the col index of non-zero elements.
+  // For row sparse, indices store row index of rows which have non-zero elements.
+  int64_t* indices{nullptr};
+  int64_t indices_len{0};
+
+  // For CSR, indptr gives the start and end index of data for each row.
+  // For row sparse, indptr is not used and stays nullptr.
+  int64_t* indptr{nullptr};
+  int64_t indptr_len{0};
+
+  /*!
+   * \brief point this view at the buffers of a sparse tensor
+   * \param data_ptr pointer to the stored values
+   * \param dims shape of the tensor (ndims entries)
+   * \param ndims number of dimensions
+   * \param idx pointer to the int64 indices array
+   * \param num_idx number of entries in idx
+   * \param idx_ptr pointer to the int64 indptr array for CSR; nullptr for row sparse
+   * \param num_idx_ptr number of entries in idx_ptr
+   */
+  void set(void *data_ptr, const int64_t* dims, int ndims, void *idx,
+           int64_t num_idx, void *idx_ptr = nullptr, int64_t num_idx_ptr = 0) {
+    data = data_ptr;
+    // If CSR, num of non-zero elements is num_idx.
+    // If row sparse, num of elements is num_idx * (product of dims[1:]).
+    data_len = num_idx;
+    if (!idx_ptr) {
+      for (int i = 1; i < ndims; ++i)
+        data_len *= dims[i];
+    }
+
+    indices = reinterpret_cast<int64_t*>(idx);
+    indices_len = num_idx;
+
+    // Always overwrite indptr so a row sparse set() never keeps a stale
+    // pointer/length left over from an earlier CSR set().
+    indptr = reinterpret_cast<int64_t*>(idx_ptr);
+    indptr_len = idx_ptr ? num_idx_ptr : 0;
+  }
+};
+
 /*!
  * \brief Tensor data structure used by custom operator
  */
 struct MXTensor {
-  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0) {}
+  MXTensor() : data_ptr(nullptr), dtype(kUNSET), verID(0), 
stype(kDefaultStorage) {}
   MXTensor(const MXTensor& oth) : data_ptr(oth.data_ptr), shape(oth.shape),
-    dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx) {
+    dtype(oth.dtype), verID(oth.verID), ctx(oth.ctx), stype(oth.stype) {
     setDLTensor();
   }
   MXTensor(void *data_ptr, const std::vector<int64_t> &shape, MXDType dtype,
-           size_t vID, MXContext mx_ctx)
-  : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx) {
+           size_t vID, MXContext mx_ctx, MXStorageType stype = kDefaultStorage)
+  : data_ptr(data_ptr), shape(shape), dtype(dtype), verID(vID), ctx(mx_ctx), 
stype(stype) {
     setDLTensor();
   }
 
   /*! \brief populate internal tensor fields */
   void setTensor(void *dptr, MXDType type, const int64_t* dims, int ndims,
-                 size_t vID, MXContext mx_ctx) {
-    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx;
+                 size_t vID, MXContext mx_ctx, MXStorageType storage_type) {
+    data_ptr = dptr; dtype = type; verID = vID; ctx = mx_ctx; stype = 
storage_type;
     shape.clear();
     for (int j = 0; j < ndims; j++) {
       shape.push_back(dims[j]);
@@ -340,11 +391,12 @@ struct MXTensor {
            verID == oth.verID &&
            ctx.dev_type == oth.ctx.dev_type &&
            ctx.dev_id == oth.ctx.dev_id &&
-           shape == oth.shape;
+           shape == oth.shape &&
+           stype == oth.stype;
   }
 
-  // data is flatten 1D repr of tensor, elements are in continuous memory
-  // user can access each element using the shape of tensor
+  // For dense, data_ptr points to data.
+  // For sparse, data_ptr points to MXSparse.
   void *data_ptr;
 
   // shape is in [2,3,4] format to represent high-dim tensor
@@ -362,11 +414,16 @@ struct MXTensor {
   // corresponding DLTensor repr of MXTensor
   // easy way to reuse functions taking DLTensor
   DLTensor dltensor;
+
+  // storage type
+  MXStorageType stype;
 };
 
 /*! \brief resource malloc function to allocate memory inside Forward/Backward 
functions */
 typedef void* (*xpu_malloc_t)(void*, int);
 
+/*!
+ * \brief sparse allocation callback invoked by OpResource::alloc_sparse with:
+ *        - the opaque allocator handle;
+ *        - the output index;
+ *        - the indices length;
+ *        - the indptr length;
+ *        - out-pointers that receive the data, indices and indptr buffers.
+ */
+typedef void (*sparse_malloc_t)(void*, int, int, int, void**, int64_t**, 
int64_t**);
+
 #if defined(__NVCC__)
   typedef cudaStream_t mx_stream_t;
 #else
@@ -379,9 +436,11 @@ typedef void* (*xpu_malloc_t)(void*, int);
 class OpResource {
  public:
   OpResource(xpu_malloc_t cpu_malloc_fp, void* cpu_alloc_fp,
-             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream)
+             xpu_malloc_t gpu_malloc_fp, void* gpu_alloc_fp, void* stream,
+             sparse_malloc_t sparse_malloc_fp, void* sparse_alloc_fp)
     : cpu_malloc(cpu_malloc_fp), gpu_malloc(gpu_malloc_fp),
-      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream) {}
+      cpu_alloc(cpu_alloc_fp), gpu_alloc(gpu_alloc_fp), cuda_stream(stream),
+      sparse_malloc(sparse_malloc_fp), sparse_alloc(sparse_alloc_fp) {}
 
   /*! \brief allocate cpu memory controlled by MXNet */
   void* alloc_cpu(int size) {
@@ -398,6 +457,12 @@ class OpResource {
     return static_cast<mx_stream_t>(cuda_stream);
   }
 
+  /*!
+   * \brief allocate sparse memory controlled by MXNet
+   * \param sparse view whose data, indices and indptr pointers are overwritten
+   *        with the buffers produced by the sparse allocation callback
+   * \param index position of the target tensor in the operator's output vector
+   * \param indices_len requested length of the indices array
+   * \param indptr_len requested length of the indptr array; 0 (the default)
+   *        when no indptr is needed, e.g. row sparse
+   * NOTE(review): lengths are int here while MXSparse stores int64_t; the data
+   * buffer size is decided by the callback (not visible here) - confirm.
+   */
+  void alloc_sparse(MXSparse* sparse, int index, int indices_len, int 
indptr_len = 0) {
+    sparse_malloc(sparse_alloc, index, indices_len, indptr_len,
+                   &(sparse->data), &(sparse->indices), &(sparse->indptr));
+  }
+  }
+
  private:
   /*! \brief allocation lambda function */
   xpu_malloc_t cpu_malloc, gpu_malloc;
@@ -405,6 +470,10 @@ class OpResource {
   void *cpu_alloc, *gpu_alloc;
   /*! \brief cuda stream passed from MXNet */
   void *cuda_stream;
+  /*! \brief sparse allocation lambda function */
+  sparse_malloc_t sparse_malloc;
+  /*! \brief lambda function to return allocated sparse memory handle */
+  void *sparse_alloc;
 };
 
 /*!
@@ -647,6 +716,8 @@ typedef MXReturnValue (*parseAttrs_t)(std::map<std::string, 
std::string>,
                                       int*, int*);
 typedef MXReturnValue (*inferType_t)(std::map<std::string, std::string>,
                                      std::vector<int>&, std::vector<int>&);
+/*!
+ * \brief storage type inference callback.
+ * Receives the op attributes, the input storage types, and the output storage
+ * types to fill in. The values are MXStorageType codes stored as int.
+ */
+typedef MXReturnValue (*inferSType_t)(std::map<std::string, std::string>,
+                                     std::vector<int>&, std::vector<int>&);
 typedef MXReturnValue (*inferShape_t)(std::map<std::string, std::string>,
                                       std::vector<std::vector<unsigned int> >&,
                                       std::vector<std::vector<unsigned int> 
>&);
@@ -660,9 +731,9 @@ typedef MXReturnValue 
(*createOpState_t)(std::map<std::string, std::string>,
  */
 class CustomOp {
  public:
-  explicit CustomOp(const char* op_name) :
-      name(op_name), parse_attrs(nullptr), infer_type(nullptr),
-      infer_shape(nullptr), mutate_inputs(nullptr), isSGop(false) {}
+  /*! \brief create an operator named op_name with no callbacks registered yet */
+  explicit CustomOp(const char* op_name) : name(op_name),
+    parse_attrs(nullptr), infer_type(nullptr), infer_storage_type(nullptr),
+    infer_shape(nullptr), mutate_inputs(nullptr), isSGop(false) {}
   CustomOp& setForward(fcomp_t fcomp, const char* ctx) {
     if (forward_ctx_map.count(ctx) > 0)
       raiseDuplicateContextError();
@@ -683,6 +754,10 @@ class CustomOp {
     infer_type = func;
     return *this;
   }
+  CustomOp& setInferSType(inferSType_t func) {  // optional; if unset, dense-only inputs/outputs
+    infer_storage_type = func;
+    return *this;
+  }
   CustomOp& setInferShape(inferShape_t func) {
     infer_shape = func;
     return *this;
@@ -723,6 +798,7 @@ class CustomOp {
   /*! \brief operator functions */
   parseAttrs_t parse_attrs;
   inferType_t infer_type;
+  inferSType_t infer_storage_type;
   inferShape_t infer_shape;
   mutateInputs_t mutate_inputs;
   bool isSGop;
@@ -876,7 +952,7 @@ typedef int (*opRegGet_t)(int idx, const char** name, int 
*isSGop,
                           const char*** backward_ctx, fcomp_t** backward_fp, 
int* backward_count,
                           const char*** create_op_ctx, createOpState_t** 
create_op_fp,
                           int* create_op_count,
-                          parseAttrs_t* parse, inferType_t* type,
+                          parseAttrs_t* parse, inferType_t* type, 
inferSType_t* stype,
                           inferShape_t* shape, mutateInputs_t* mutate);
 
 #define MXLIB_OPCALLFREE_STR "_opCallFree"
@@ -898,6 +974,11 @@ typedef int (*opCallInferType_t)(inferType_t inferType, 
const char* const* keys,
                                  const char* const* vals, int num,
                                  int* intypes, int num_in, int* outtypes, int 
num_out);
 
+#define MXLIB_OPCALLINFERSTYPE_STR "_opCallInferSType"
+typedef int (*opCallInferSType_t)(inferSType_t inferSType, const char* const* 
keys,
+                                 const char* const* vals, int num,
+                                 int* intypes, int num_in, int* outtypes, int 
num_out);
+
 #define MXLIB_OPCALLFCOMP_STR "_opCallFCompute"
 typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* const* keys,
                              const char* const* vals, int num,
@@ -910,7 +991,13 @@ typedef int (*opCallFComp_t)(fcomp_t fcomp, const char* 
const* keys,
                              size_t* outIDs, const char** outdev_type,
                              int* outdev_id, int num_out,
                              xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* 
cuda_stream);
+                             xpu_malloc_t gpu_malloc, void* gpu_alloc, void* 
cuda_stream,
+                             sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                             int* instypes, int* outstypes,
+                             void** in_indices, void** out_indices,
+                             void** in_indptr, void** out_indptr,
+                             int64_t* in_indices_shapes, int64_t* 
out_indices_shapes,
+                             int64_t* in_indptr_shapes, int64_t* 
out_indptr_shapes);
 
 #define MXLIB_OPCALLMUTATEINPUTS_STR "_opCallMutateInputs"
 typedef int (*opCallMutateInputs_t)(mutateInputs_t mutate, const char* const* 
keys,
@@ -933,7 +1020,13 @@ typedef int (*opCallFStatefulComp_t)(int is_forward, 
void* state_op,
                                      size_t* outIDs, const char** outdev_type,
                                      int* outdev_id, int num_out,
                                      xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                                     xpu_malloc_t gpu_malloc, void* gpu_alloc, 
void* stream);
+                                     xpu_malloc_t gpu_malloc, void* gpu_alloc, 
void* stream,
+                                     sparse_malloc_t sparse_malloc, void* 
sparse_alloc,
+                                     int* instypes, int* outstypes,
+                                     void** in_indices, void** out_indices,
+                                     void** in_indptr, void** out_indptr,
+                                     int64_t* in_indices_shapes, int64_t* 
out_indices_shapes,
+                                     int64_t* in_indptr_shapes, int64_t* 
out_indptr_shapes);
 
 #define MXLIB_PARTREGSIZE_STR "_partRegSize"
 typedef int (*partRegSize_t)(void);
@@ -1004,12 +1097,13 @@ extern "C" {
             const char*** forward_ctx, fcomp_t** forward_fp, int* 
forward_count,
             const char*** backward_ctx, fcomp_t** backward_fp, int* 
backward_count,
             const char*** create_op_ctx, createOpState_t** create_op_fp, int* 
create_op_count,
-            parseAttrs_t* parse, inferType_t* type,
+            parseAttrs_t* parse, inferType_t* type, inferSType_t* stype,
             inferShape_t* shape, mutateInputs_t* mutate) {
     CustomOp &op = Registry<CustomOp>::get()->get(idx);
     *name = op.name;
     *parse = op.parse_attrs;
     *type = op.infer_type;
+    *stype = op.infer_storage_type;
     *shape = op.infer_shape;
     *mutate = op.mutate_inputs;
     *isSGop = op.isSGop;
@@ -1136,6 +1230,43 @@ extern "C" {
     return retval;
   }
 
+  /*! \brief returns status of calling inferSType function for operator from 
library */
+#if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
+  __declspec(dllexport) int __cdecl
+#else
+  int
+#endif
+  _opCallInferSType(inferSType_t inferSType, const char* const* keys,
+                   const char* const* vals, int num,
+                   int* instypes, int num_in, int* outstypes, int num_out) {
+    // build the attribute map from the parallel keys/vals arrays (num entries)
+    std::map<std::string, std::string> attrs;
+    for (int i = 0; i < num; i++) {
+      attrs[std::string(keys[i])] = std::string(vals[i]);
+    }
+
+    // copy the caller's num_in input storage types into a vector
+    std::vector<int> in_stypes(num_in);
+    for (int i = 0; i < num_in; i++) {
+      in_stypes[i] = instypes[i];
+    }
+
+    // output storage types start at -1 and are expected to be set by inferSType
+    std::vector<int> out_stypes(num_out, -1);
+
+    int retval = inferSType(attrs, in_stypes, out_stypes);
+
+    if (!retval)  // failure: return without writing to outstypes
+      return retval;
+
+    // copy the inferred output storage types back to the caller's array
+    for (int i = 0; i < num_out; i++) {
+      outstypes[i] = out_stypes[i];
+    }
+
+    return retval;
+  }
+
   /*! \brief returns status of calling Forward/Backward function for operator 
from library */
 #if defined(_WIN32) || defined(_WIN64) || defined(__WINDOWS__)
   __declspec(dllexport) int __cdecl
@@ -1148,7 +1279,12 @@ extern "C" {
                   const int64_t** outshapes, int* outdims, void** outdata, 
int* outtypes,
                   size_t* outIDs, const char** outdev_type, int* outdev_id, 
int num_out,
                   xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream) 
{
+                  xpu_malloc_t gpu_malloc, void* gpu_alloc, void* cuda_stream,
+                  sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                  int* instypes, int* outstypes, void** in_indices, void** 
out_indices,
+                  void** in_indptr, void** out_indptr,
+                  int64_t* in_indices_shapes, int64_t* out_indices_shapes,
+                  int64_t* in_indptr_shapes, int64_t* out_indptr_shapes) {
     // create map of attributes from list
     std::map<std::string, std::string> attrs;
     for (int i = 0; i < num; i++) {
@@ -1157,20 +1293,59 @@ extern "C" {
 
     // create a vector of tensors for inputs
     std::vector<MXTensor> inputs(num_in);
+    // create a vector for sparse inputs
+    std::vector<MXSparse> in_sparse(num_in);
+
     for (int i = 0; i < num_in; i++) {
-      inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], 
indims[i],
-                          inIDs[i], {indev_type[i], indev_id[i]});
+      // Dense representation.
+      if (instypes[i] == 0) {
+        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], 
indims[i],
+                            inIDs[i], {indev_type[i], indev_id[i]}, 
kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (instypes[i] == 1) {
+          type = kRowSparseStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], 
in_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+                           in_indices_shapes[i], in_indptr[i], 
in_indptr_shapes[i]);
+        }
+        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), 
(MXDType)intypes[i],
+                            inshapes[i], indims[i], inIDs[i], {indev_type[i], 
indev_id[i]}, type);
+      }
     }
 
     // create a vector of tensors for outputs
     std::vector<MXTensor> outputs(num_out);
+    std::vector<MXSparse> out_sparse(num_out);
+
     for (int i = 0; i < num_out; i++) {
-      outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], 
outdims[i],
-                           outIDs[i], {outdev_type[i], outdev_id[i]});
+      // Dense representation.
+      if (outstypes[i] == 0) {
+        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], 
outdims[i],
+                            outIDs[i], {outdev_type[i], outdev_id[i]}, 
kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (outstypes[i] == 1) {
+          type = kRowSparseStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i],
+                            out_indices[i], out_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], 
out_indices[i],
+                            out_indices_shapes[i], out_indptr[i], 
out_indptr_shapes[i]);
+        }
+        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), 
(MXDType)outtypes[i],
+                            outshapes[i], outdims[i], outIDs[i], 
{outdev_type[i],
+                            outdev_id[i]}, type);
+      }
     }
 
-    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, cuda_stream);
-
+    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                   cuda_stream, sparse_malloc, sparse_alloc);
     return fcomp(attrs, inputs, outputs, res);
   }
 
@@ -1239,22 +1414,69 @@ extern "C" {
                           const int64_t** outshapes, int* outdims, void** 
outdata, int* outtypes,
                           size_t* outIDs, const char** outdev_type, int* 
outdev_id, int num_out,
                           xpu_malloc_t cpu_malloc, void* cpu_alloc,
-                          xpu_malloc_t gpu_malloc, void* gpu_alloc, void* 
stream) {
+                          xpu_malloc_t gpu_malloc, void* gpu_alloc, void* 
stream,
+                          sparse_malloc_t sparse_malloc, void* sparse_alloc,
+                          int* instypes, int* outstypes, void** in_indices, 
void** out_indices,
+                          void** in_indptr, void** out_indptr,
+                          int64_t* in_indices_shapes, int64_t* 
out_indices_shapes,
+                          int64_t* in_indptr_shapes, int64_t* 
out_indptr_shapes) {
     // create a vector of tensors for inputs
     std::vector<MXTensor> inputs(num_in);
+    // create a vector for sparse inputs
+    std::vector<MXSparse> in_sparse(num_in);
+
     for (int i = 0; i < num_in; i++) {
-      inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], 
indims[i],
-                          inIDs[i], {indev_type[i], indev_id[i]});
+      if (instypes[i] == 0) {
+        // Dense representation.
+        inputs[i].setTensor(indata[i], (MXDType)intypes[i], inshapes[i], 
indims[i],
+                            inIDs[i], {indev_type[i], indev_id[i]}, 
kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (instypes[i] == 1) {
+          type = kRowSparseStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i], 
in_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          in_sparse[i].set(indata[i], inshapes[i], indims[i], in_indices[i],
+                           in_indices_shapes[i], in_indptr[i], 
in_indptr_shapes[i]);
+        }
+        inputs[i].setTensor(reinterpret_cast<void*>(&in_sparse[i]), 
(MXDType)intypes[i],
+                            inshapes[i], indims[i], inIDs[i], {indev_type[i],
+                            indev_id[i]}, type);
+      }
     }
 
     // create a vector of tensors for outputs
     std::vector<MXTensor> outputs(num_out);
+    // create a vector for sparse outputs
+    std::vector<MXSparse> out_sparse(num_out);
+
     for (int i = 0; i < num_out; i++) {
-      outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], 
outdims[i],
-                           outIDs[i], {outdev_type[i], outdev_id[i]});
+      if (outstypes[i] == 0) {
+        // Dense representation.
+        outputs[i].setTensor(outdata[i], (MXDType)outtypes[i], outshapes[i], 
outdims[i],
+                             outIDs[i], {outdev_type[i], outdev_id[i]}, 
kDefaultStorage);
+      } else {
+        // Sparse representation.
+        MXStorageType type;
+        if (outstypes[i] == 1) {
+          type = kRowSparseStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], 
out_indices[i],
+                            out_indices_shapes[i]);
+        } else {
+          type = kCSRStorage;
+          out_sparse[i].set(outdata[i], outshapes[i], outdims[i], 
out_indices[i],
+                            out_indices_shapes[i], out_indptr[i], 
out_indptr_shapes[i]);
+        }
+        outputs[i].setTensor(reinterpret_cast<void*>(&out_sparse[i]), 
(MXDType)outtypes[i],
+                             outshapes[i], outdims[i], outIDs[i], 
{outdev_type[i],
+                             outdev_id[i]}, type);
+      }
     }
 
-    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc, stream);
+    OpResource res(cpu_malloc, cpu_alloc, gpu_malloc, gpu_alloc,
+                   stream, sparse_malloc, sparse_alloc);
 
     CustomStatefulOp* op_ptr = reinterpret_cast<CustomStatefulOp*>(state_op);
     if (is_forward) {
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index db0e262..fe00a9a 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -114,7 +114,7 @@ void CustomFComputeDispatcher(const std::string op_name,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
   std::vector<void*> in_data, out_data;
-  std::vector<const int64_t *> in_shapes, out_shapes;
+  std::vector<const int64_t*> in_shapes, out_shapes;
   std::vector<int> in_dims, out_dims;
   std::vector<int> in_types, out_types;
   std::vector<size_t> in_verIDs, out_verIDs;
@@ -122,6 +122,13 @@ void CustomFComputeDispatcher(const std::string op_name,
   std::vector<int> in_dev_id, out_dev_id;
   std::vector<NDArray> conv_mkl;  // converted NDArrays from MKLDNN format
 
+  // Extra data for sparse inputs and outputs.
+  std::vector<int> in_stypes(inputs.size(), 0), out_stypes(outputs.size(), 0);
+  std::vector<void*> in_indices(inputs.size(), nullptr), 
out_indices(outputs.size(), nullptr);
+  std::vector<void*> in_indptr(inputs.size(), nullptr), 
out_indptr(outputs.size(), nullptr);
+  std::vector<int64_t> in_indices_shapes(inputs.size(), 0), 
out_indices_shapes(outputs.size(), 0);
+  std::vector<int64_t> in_indptr_shapes(inputs.size(), 0), 
out_indptr_shapes(outputs.size(), 0);
+
   // convert inputs/outpus NDArray to C types to be passed to lib_api.h
   for (size_t i = 0; i < inputs.size(); i++) {
     NDArray const* in_nd = &(inputs[i]);
@@ -141,7 +148,19 @@ void CustomFComputeDispatcher(const std::string op_name,
     in_verIDs.push_back(in_nd->version());
     const char* ctx_str = in_nd->ctx().dev_mask() == Context::kCPU ? "cpu" : 
"gpu";
     in_dev_type.push_back(ctx_str);
+
     in_dev_id.push_back(in_nd->ctx().real_dev_id());
+    if (inputs[i].storage_type() == mxnet::kRowSparseStorage) {
+      in_stypes[i] = 1;
+      in_indices[i] = inputs[i].aux_data(rowsparse::kIdx).dptr_;
+      in_indices_shapes[i] = inputs[i].aux_shape(rowsparse::kIdx).Size();
+    } else if (inputs[i].storage_type() == mxnet::kCSRStorage) {
+      in_stypes[i] = 2;
+      in_indices[i] = inputs[i].aux_data(csr::kIdx).dptr_;
+      in_indptr[i] = inputs[i].aux_data(csr::kIndPtr).dptr_;
+      in_indices_shapes[i] = inputs[i].aux_shape(csr::kIdx).Size();
+      in_indptr_shapes[i] = inputs[i].aux_shape(csr::kIndPtr).Size();
+    }
   }
 
   for (size_t i = 0; i < outputs.size(); i++) {
@@ -153,6 +172,18 @@ void CustomFComputeDispatcher(const std::string op_name,
     const char* ctx_str = outputs[i].ctx().dev_mask() == Context::kCPU ? "cpu" 
: "gpu";
     out_dev_type.push_back(ctx_str);
     out_dev_id.push_back(outputs[i].ctx().real_dev_id());
+
+    if (outputs[i].storage_type() == mxnet::kRowSparseStorage) {
+      out_stypes[i] = 1;
+      out_indices[i] = outputs[i].aux_data(rowsparse::kIdx).dptr_;
+      out_indices_shapes[i] = outputs[i].aux_shape(rowsparse::kIdx).Size();
+    } else if (outputs[i].storage_type() == mxnet::kCSRStorage) {
+      out_stypes[i] = 2;
+      out_indices[i] = outputs[i].aux_data(csr::kIdx).dptr_;
+      out_indptr[i] = outputs[i].aux_data(csr::kIndPtr).dptr_;
+      out_indices_shapes[i] = outputs[i].aux_shape(csr::kIdx).Size();
+      out_indptr_shapes[i] = outputs[i].aux_shape(csr::kIndPtr).Size();
+    }
   }
 
   // get memory resource and mxnet backend streams
@@ -173,6 +204,24 @@ void CustomFComputeDispatcher(const std::string op_name,
     return workspace.dptr_;
   };
 
+  // lambda that allocates storage for sparse output `outputs[index]` and
+  // returns pointers to its data, indices and (CSR only) indptr arrays.
+  auto sparse_alloc = [&](int index, int indices_len, int idxptr_len,
+                           void** data, int64_t** indices, int64_t** indptr) {
+    if (idxptr_len == 0) {
+      // Row Sparse (idxptr_len == 0): single aux shape; *indptr is not written
+      outputs[index].CheckAndAlloc({mshadow::Shape1(indices_len)});
+      *data = outputs[index].data().dptr_;
+      *indices = 
reinterpret_cast<int64_t*>(outputs[index].aux_data(rowsparse::kIdx).dptr_);
+    } else {
+      // CSR: aux shapes are passed in {indptr, indices} order
+      outputs[index].CheckAndAlloc({mshadow::Shape1(idxptr_len), 
mshadow::Shape1(indices_len)});
+      *data = outputs[index].data().dptr_;
+      *indices = 
reinterpret_cast<int64_t*>(outputs[index].aux_data(csr::kIdx).dptr_);
+      *indptr = 
reinterpret_cast<int64_t*>(outputs[index].aux_data(csr::kIndPtr).dptr_);
+    }
+  };
+
   // create lambda without captures so that we can cast it to function pointer
   // lambda with captures cannot be cast to function pointer and pass to 
lib_api.h
   // this needs to be a lambda function so that we can do the decltype cast
@@ -189,6 +238,13 @@ void CustomFComputeDispatcher(const std::string op_name,
     return static_cast<void*>((*gpualloc)(size));
   };
 
+  typedef decltype(sparse_alloc) alloc_type_sparse;  // recovers sparse_alloc from void*
+  auto sparse_malloc = [](void* _sparse_alloc, int index, int indices_len, int 
idxptr_len,
+                           void** data, int64_t** indices, int64_t** indptr) {
+    alloc_type_sparse* sparsealloc = 
static_cast<alloc_type_sparse*>(_sparse_alloc);
+    (*sparsealloc)(index, indices_len, idxptr_len, data, indices, indptr);
+  };  // captureless, so it converts to sparse_malloc_t
+
   // get actual cudaStream_t out of mxnet gpu stream and pass to lib_api.h
   void *cuda_stream = nullptr;
 #if MXNET_USE_CUDA
@@ -208,13 +264,18 @@ void CustomFComputeDispatcher(const std::string op_name,
       attr_keys.push_back(kv.first.c_str());
       attr_vals.push_back(kv.second.c_str());
     }
+
     // call fcompute function
     CHECK(callFComp(fcomp_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
                     in_shapes.data(), in_dims.data(), in_data.data(), 
in_types.data(),
                     in_verIDs.data(), in_dev_type.data(), in_dev_id.data(), 
in_data.size(),
                     out_shapes.data(), out_dims.data(), out_data.data(), 
out_types.data(),
                     out_verIDs.data(), out_dev_type.data(), out_dev_id.data(), 
out_data.size(),
-                    cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, 
cuda_stream))
+                    cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, 
cuda_stream,
+                    sparse_malloc, &sparse_alloc, in_stypes.data(), 
out_stypes.data(),
+                    in_indices.data(), out_indices.data(), in_indptr.data(), 
out_indptr.data(),
+                    in_indices_shapes.data(), out_indices_shapes.data(),
+                    in_indptr_shapes.data(), out_indptr_shapes.data()))
       << "Error calling FCompute for custom operator '" << op_name << "'";
   }
 
@@ -233,7 +294,12 @@ void CustomFComputeDispatcher(const std::string op_name,
                             out_shapes.data(), out_dims.data(), 
out_data.data(), out_types.data(),
                             out_verIDs.data(), out_dev_type.data(), 
out_dev_id.data(),
                             out_data.size(),
-                            cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, 
cuda_stream))
+                            cpu_malloc, &cpu_alloc, gpu_malloc, &gpu_alloc, 
cuda_stream,
+                            sparse_malloc, &sparse_alloc, in_stypes.data(), 
out_stypes.data(),
+                            in_indices.data(), out_indices.data(),
+                            in_indptr.data(), out_indptr.data(),
+                            in_indices_shapes.data(), 
out_indices_shapes.data(),
+                            in_indptr_shapes.data(), out_indptr_shapes.data()))
       << "Error calling FStatefulCompute for custom operator '" << op_name << 
"'";
   }
 }
@@ -272,6 +338,9 @@ int MXLoadLib(const char *path) {
   opCallInferType_t callInferType =
     get_func<opCallInferType_t>(lib, 
const_cast<char*>(MXLIB_OPCALLINFERTYPE_STR));
 
+  opCallInferSType_t callInferSType =
+    get_func<opCallInferSType_t>(lib, 
const_cast<char*>(MXLIB_OPCALLINFERSTYPE_STR));
+
   opCallFComp_t callFComp =
     get_func<opCallFComp_t>(lib, const_cast<char*>(MXLIB_OPCALLFCOMP_STR));
 
@@ -306,6 +375,7 @@ int MXLoadLib(const char *path) {
     // function pointers holding implementation from custom library
     parseAttrs_t parse_fp = nullptr;
     inferType_t type_fp = nullptr;
+    inferSType_t stype_fp = nullptr;
     inferShape_t shape_fp = nullptr;
     // optional attributes
     mutateInputs_t mutate_fp = nullptr;
@@ -322,7 +392,7 @@ int MXLoadLib(const char *path) {
              &forward_ctx, &forward_fcomp, &forward_count,
              &backward_ctx, &backward_fcomp, &backward_count,
              &createop_ctx, &createop_fp, &createop_count,
-             &parse_fp, &type_fp, &shape_fp, &mutate_fp);
+             &parse_fp, &type_fp, &stype_fp, &shape_fp, &mutate_fp);
 
     // construct maps of context to forward/backward custom library function
     std::unordered_map<std::string, fcomp_t> forward_ctx_map;
@@ -583,12 +653,39 @@ int MXLoadLib(const char *path) {
                                 DispatchMode* dispatch_mode,
                                 std::vector<int>* in_stypes,
                                 std::vector<int>* out_stypes) {
-      // TODO(ziyimu): remove this dense enforce check after supporting sparse 
tensor
-      CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, 
mxnet::kDefaultStorage))
-      << "Error input tensors are not dense for custom operator '" << name_str 
<< "'";
-      // set outputs as dense
-      return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage,
-                                     dispatch_mode, DispatchMode::kFComputeEx);
+      if (stype_fp == nullptr) {
+        // InferSType is not defined in the custom library.
+        CHECK(mxnet::common::ContainsOnlyStorage(*in_stypes, 
mxnet::kDefaultStorage))
+        << "Error input tensors are not dense for custom operator '" << 
name_str << "'";
+        // set outputs as dense
+        return op::storage_type_assign(out_stypes, mxnet::kDefaultStorage,
+                                       dispatch_mode, 
DispatchMode::kFComputeEx);
+      } else {
+        // InferSType is defined in the custom library.
+        // convert attributes to vector of char*
+        std::vector<const char*> attr_keys, attr_vals;
+        for (auto kv : attrs.dict) {
+          attr_keys.push_back(kv.first.c_str());
+          attr_vals.push_back(kv.second.c_str());
+        }
+        // copy input storage types from in_stypes
+        std::vector<int> instypes(*in_stypes);
+
+        // output storage types will be populated by the inferSType function
+        std::vector<int> outstypes(out_stypes->size());
+        CHECK(callInferSType(stype_fp, attr_keys.data(), attr_vals.data(), 
attr_keys.size(),
+                             instypes.data(), in_stypes->size(),
+                             outstypes.data(), out_stypes->size()))
+        << "Error calling InferSType for custom operator '" << name_str << "'";
+
+        // copy and assign output storage types from custom op to MXNet memory.
+        for (size_t i = 0; i < out_stypes->size(); i++) {
+          STORAGE_TYPE_ASSIGN_CHECK(*out_stypes, i, outstypes[i]);
+        }
+        // assign dispatch mode
+        DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, 
DispatchMode::kFComputeEx);
+        return true;
+      }
     };
 
     // FGradient register lambda
@@ -698,8 +795,8 @@ int MXLoadLib(const char *path) {
       regOp.set_num_inputs(num_inputs);
       regOp.set_num_outputs(num_outputs);
       regOp.set_attr<nnvm::FInferType>("FInferType", infer_type, plevel);
-      regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
       regOp.set_attr<FInferStorageType>("FInferStorageType", 
infer_storage_type, plevel);
+      regOp.set_attr<mxnet::FInferShape>("FInferShape", infer_shape, plevel);
       regOp.set_attr<FResourceRequest>("FResourceRequest", resc_req, plevel);
       // optionally add fmutate inputs if user specified a function
       if (mutate_fp != nullptr)

Reply via email to