This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 09c71bf  Add Sparse NDArray support for Scala (#15378)
09c71bf is described below

commit 09c71bf3144b09a28b9d09d33703a3dcbf4ca9a5
Author: Lanking <lanking...@live.com>
AuthorDate: Mon Jul 8 11:48:30 2019 -0700

    Add Sparse NDArray support for Scala (#15378)
    
    * add Sparse Support
    
    * add imperative invoke sparse support
    
    * add retain method and comments
    
    * add getData method
    
    * add Sparse NDIter test
    
    * remove debug line
---
 .../src/main/scala/org/apache/mxnet/DType.scala    |  17 +-
 .../src/main/scala/org/apache/mxnet/Executor.scala |   9 +-
 .../src/main/scala/org/apache/mxnet/LibInfo.scala  |  27 ++-
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |  65 ++++++-
 .../main/scala/org/apache/mxnet/SparseFormat.scala |  25 +++
 .../scala/org/apache/mxnet/SparseNDArray.scala     | 196 +++++++++++++++++++++
 .../test/scala/org/apache/mxnet/NDArraySuite.scala |  16 ++
 .../org/apache/mxnet/SparseNDArraySuite.scala      |  93 ++++++++++
 .../main/native/org_apache_mxnet_native_c_api.cc   |  75 +++++++-
 .../main/native/org_apache_mxnet_native_c_api.h    |  48 ++++-
 10 files changed, 543 insertions(+), 28 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
index f3a8e8e..1d5cc28 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
@@ -24,26 +24,17 @@ object DType extends Enumeration {
   val Float16 = Value(2, "float16")
   val UInt8 = Value(3, "uint8")
   val Int32 = Value(4, "int32")
+  val Int8 = Value(5, "int8")  // 1 byte per element (see numOfBytes)
+  val Int64 = Value(6, "int64")  // 8 bytes per element (see numOfBytes)
   val Unknown = Value(-1, "unknown")
   private[mxnet] def numOfBytes(dtype: DType): Int = {
     dtype match {
-      case DType.UInt8 => 1
+      case DType.UInt8 | DType.Int8 => 1  // single-byte element types
       case DType.Int32 => 4
       case DType.Float16 => 2
       case DType.Float32 => 4
-      case DType.Float64 => 8
+      case DType.Float64 | DType.Int64 => 8  // eight-byte element types
       case DType.Unknown => 0
     }
   }
-  private[mxnet] def getType(dtypeStr: String): DType = {
-    dtypeStr match {
-      case "UInt8" => DType.UInt8
-      case "Int32" => DType.Int32
-      case "Float16" => DType.Float16
-      case "Float32" => DType.Float32
-      case "Float64" => DType.Float64
-      case _ => throw new IllegalArgumentException(
-        s"DType: $dtypeStr not found! please set it in DType.scala")
-    }
-  }
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index b0fae0f..6365f9c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -159,7 +159,11 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
   private def getOutputs: Array[NDArray] = {
     val ndHandles = ArrayBuffer[NDArrayHandle]()
     checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
+    // Each output handle is wrapped exactly once, with addToCollector = false as before.
+    // Sparse outputs stay plain NDArray instances (check isSparse / sparseFormat): an object
+    // constructed as NDArray can never be downcast to SparseNDArray, so casting it here could
+    // only throw a ClassCastException or be discarded.
     ndHandles.toArray.map(new NDArray(_, addToCollector = false))
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 640ecf5..0ee6476 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -31,13 +31,14 @@ private[mxnet] class LibInfo {
   @native def mxListAllOpNames(names: ListBuffer[String]): Int
   @native def nnGetOpHandle(opName: String, opHandle: RefLong): Int
   // NDArray
-  @native def mxImperativeInvoke(creator: FunctionHandle,
+  @native def mxImperativeInvokeEx(creator: FunctionHandle,  // storage-type aware invoke
                                  inputs: Array[NDArrayHandle],
                                  outputsGiven: Array[NDArrayHandle],
                                  outputs: ArrayBuffer[NDArrayHandle],
                                  numParams: Int,
                                  paramKeys: Array[String],
-                                 paramVals: Array[String]): Int
+                                 paramVals: Array[String],
+                                 outStype: ArrayBuffer[Int]): Int  // one storage-type id per output
   @native def mxNDArrayFree(handle: NDArrayHandle): Int
   @native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int
   @native def mxNDArrayCreateEx(shape: Array[Int],
@@ -47,6 +48,20 @@ private[mxnet] class LibInfo {
                                  delayAlloc: Int,
                                  dtype: Int,
                                  out: NDArrayHandleRef): Int
+  // scalastyle:off parameterNum
+  @native def mxNDArrayCreateSparseEx(storageType: Int,  // a SparseFormat id
+                                      shape: Array[Int],
+                                      ndim: Int,
+                                      devType: Int,
+                                      devId: Int,
+                                      delayAlloc: Int,
+                                      dtype: Int,
+                                      numAux: Int,
+                                      auxTypes: Array[Int],
+                                      auxNdims: Array[Int],
+                                      auxShapes: Array[Int],  // NOTE(review): presumably all aux shapes concatenated - confirm
+                                      out: NDArrayHandleRef): Int
+  // scalastyle:on parameterNum
   @native def mxNDArrayWaitAll(): Int
   @native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int
   @native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int
@@ -76,6 +91,9 @@ private[mxnet] class LibInfo {
   @native def mxNDArrayGetShape(handle: NDArrayHandle,
                                 ndim: MXUintRef,
                                 data: ArrayBuffer[Int]): Int
+  @native def mxNDArraySyncCopyFromNDArray(handleDst: NDArrayHandle,
+                                           handleSrc: NDArrayHandle,
+                                           locator: Int): Int  // -1 = data blob, >= 0 = aux index
   @native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle,
                                      data: Array[Byte],
                                      size: Int): Int
@@ -105,10 +123,15 @@ private[mxnet] class LibInfo {
   @native def mxNDArraySave(fname: String,
                             handles: Array[NDArrayHandle],
                             keys: Array[String]): Int
+  @native def mxNDArrayGetDataNDArray(handle: NDArrayHandle, out: 
NDArrayHandleRef): Int  // data blob of a sparse array
+  @native def mxNDArrayGetAuxNDArray(handle: NDArrayHandle,
+                                     location: Int,
+                                     out: NDArrayHandleRef): Int  // aux array at `location`
   @native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt, 
devId: RefInt): Int
   @native def mxNDArraySaveRawBytes(handle: NDArrayHandle, buf: 
ArrayBuffer[Byte]): Int
   @native def mxNDArrayLoadFromRawBytes(bytes: Array[Byte], handle: 
NDArrayHandleRef): Int
   @native def mxNDArrayGetDType(handle: NDArrayHandle, dtype: RefInt): Int
+  @native def mxNDArrayGetStorageType(handle: NDArrayHandle, stype: RefInt): 
Int  // stype receives a SparseFormat id
 
   // KVStore Server
   @native def mxInitPSEnv(keys: Array[String], values: Array[String]): Int
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 4088801..1b7b31b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -21,7 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder}
 
 import org.apache.mxnet.Base._
 import org.apache.mxnet.DType.DType
-import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
+import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
+import org.apache.mxnet.SparseFormat.SparseFormat
 import org.slf4j.LoggerFactory
 
 import scala.collection.mutable
@@ -113,10 +114,22 @@ object NDArray extends NDArrayBase {
       }
 
     val outputs = ArrayBuffer.empty[NDArrayHandle]
-    checkCall(_LIB.mxImperativeInvoke(function.handle, ndArgs.map(_.handle).toArray, outputVars,
-      outputs, updatedKwargs.size, updatedKwargs.keys.toArray, updatedKwargs.values.toArray))
+    val outStypes = ArrayBuffer.empty[Int]
+    checkCall(_LIB.mxImperativeInvokeEx(function.handle,
+      ndArgs.map(_.handle).toArray,
+      outputVars,
+      outputs,
+      updatedKwargs.size,
+      updatedKwargs.keys.toArray,
+      updatedKwargs.values.toArray,
+      outStypes))
     new NDArrayFuncReturn(Option(oriOutputs).getOrElse {
-      val outputArrs = outputs.map(new NDArray(_)).toArray
+      val outputArrs = (outputs zip outStypes).map(
+        ele => ele._2 match {
+          case 0 => new NDArray(ele._1)
+          case _ => new SparseNDArray(ele._1)
+        }
+        ).toArray
       addDependency(ndArgs.toArray, outputArrs)
       outputArrs
     })
@@ -943,6 +956,16 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     DType(mxDtype.value)
   }
 
+  /**
+   * Storage format of this array (dense, row_sparse or csr). It is read from the
+   * native handle once, when this wrapper is constructed.
+   */
+  val sparseFormat: SparseFormat = {
+    val mxSF = new RefInt
+    checkCall(_LIB.mxNDArrayGetStorageType(handle, mxSF))
+    SparseFormat(mxSF.value)
+  }
+
   /**
    * Return a copied numpy array of current array with specified type.
    * @param dtype Desired type of result array.
@@ -1309,6 +1332,42 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
     if (this.context == context) this else this.copyTo(context)
   }
 
+  /**
+   * Check whether this array has sparse (row_sparse or csr) storage.
+   * @return true when [[sparseFormat]] is not [[SparseFormat.DEFAULT]]
+   */
+  def isSparse: Boolean = sparseFormat != SparseFormat.DEFAULT
+
+  /**
+   * Convert a NDArray to SparseNDArray
+   *
+   * When no target format is given and the storage is already sparse, a
+   * [[SparseNDArray]] instance is returned as-is. A plain NDArray wrapper around
+   * sparse storage (e.g. an executor output) cannot be downcast, so a
+   * SparseNDArray copy in the current format is built instead.
+   *
+   * @param sfOption the target sparse type; ROW_SPARSE is used when None and dense
+   * @return SparseNDArray
+   * @throws IllegalArgumentException if the requested format is SparseFormat.DEFAULT
+   */
+  def toSparse(sfOption : Option[SparseFormat] = None): SparseNDArray = {
+    val sf = sfOption.getOrElse(SparseFormat.ROW_SPARSE)
+    if (sf == SparseFormat.DEFAULT) throw new IllegalArgumentException("Require Sparse")
+    this match {
+      case sparse: SparseNDArray if isSparse && sfOption.isEmpty => sparse
+      case _ if isSparse && sfOption.isEmpty =>
+        // round-trip through dense storage, since dense -> sparse is the cast used below
+        val dense = NDArray.api.cast_storage(this, SparseFormat.DEFAULT.toString).head
+        try {
+          NDArray.api.cast_storage(dense, sparseFormat.toString).head.asInstanceOf[SparseNDArray]
+        } finally {
+          dense.dispose()
+        }
+      case _ =>
+        NDArray.api.cast_storage(this, sf.toString).head.asInstanceOf[SparseNDArray]
+    }
+  }
+
   override def equals(o: Any): Boolean = o match {
     case that: NDArray =>
       that != null && that.shape == this.shape && that.toArray.sameElements(this.toArray)
@@ -1479,6 +1538,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
       case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble)
       case DType.Float64 => units.map(wrapBytes(_).getDouble)
       case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble)
+      case DType.Int64 => units.map(wrapBytes(_).getLong.toDouble)
       case DType.UInt8 => internal.map(_.toDouble)
+      case DType.Int8 => internal.map(_.toDouble)  // JVM bytes are signed, as int8 is
     }
   }
@@ -1488,6 +1549,8 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
       case DType.Float32 => units.map(wrapBytes(_).getFloat)
       case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat)
       case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat)
+      case DType.Int64 => units.map(wrapBytes(_).getLong.toFloat)
       case DType.UInt8 => internal.map(_.toFloat)
+      case DType.Int8 => internal.map(_.toFloat)
     }
   }
@@ -1497,15 +1560,30 @@ private[mxnet] class NDArrayInternal (private val internal: Array[Byte], private
       case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt)
       case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt)
       case DType.Int32 => units.map(wrapBytes(_).getInt)
+      case DType.Int64 => units.map(wrapBytes(_).getLong.toInt)
       case DType.UInt8 => internal.map(_.toInt)
+      case DType.Int8 => internal.map(_.toInt)
     }
   }
+  def toLongArray: Array[Long] = {
+    require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
+    dtype match {
+      case DType.Float32 => units.map(wrapBytes(_).getFloat.toLong)
+      case DType.Float64 => units.map(wrapBytes(_).getDouble.toLong)
+      case DType.Int32 => units.map(wrapBytes(_).getInt.toLong)
+      case DType.Int64 => units.map(wrapBytes(_).getLong)
+      case DType.UInt8 => internal.map(_.toLong)
+      case DType.Int8 => internal.map(_.toLong)
+    }
+  }
   def toByteArray: Array[Byte] = {
     require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
     dtype match {
       case DType.Float16 | DType.Float32 => units.map(wrapBytes(_).getFloat.toByte)
       case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte)
       case DType.Int32 => units.map(wrapBytes(_).getInt.toByte)
+      case DType.Int64 => units.map(wrapBytes(_).getLong.toByte)
       case DType.UInt8 => internal.clone()
+      case DType.Int8 => internal.clone()
     }
   }
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala
new file mode 100644
index 0000000..acb0c0f
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+object SparseFormat extends Enumeration {  // storage formats; ids match mxNDArrayGetStorageType
+  type SparseFormat = Value
+  val DEFAULT = Value(0, "default")  // dense storage
+  val ROW_SPARSE = Value(1, "row_sparse")  // only non-zero rows are stored, with their row indices
+  val CSR = Value(2, "csr")  // compressed sparse row: data, column indices and indptr
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala
new file mode 100644
index 0000000..f3fe638
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base.{NDArrayHandle, NDArrayHandleRef, checkCall, _LIB}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.SparseFormat.SparseFormat
+
+object SparseNDArray {
+  /**
+    * Create a Compressed Sparse Row Storage (CSR) Format Matrix
+    * @param data the data to feed
+    * @param indices The indices array stores the column index for each non-zero element in data
+    * @param indptr The indptr array is what will help identify the rows where the data appears
+    * @param shape the shape of CSR NDArray to be created
+    * @param ctx the context of this NDArray
+    * @return SparseNDArray
+    */
+  def csrMatrix(data: Array[Float], indices: Array[Float],
+                indptr: Array[Float], shape: Shape, ctx: Context): SparseNDArray = {
+    val fmt = SparseFormat.CSR
+    val dataND = NDArray.array(data, Shape(data.length), ctx)
+    val indicesND = toInt64(indices, ctx)
+    val indptrND = toInt64(indptr, ctx)
+    try {
+      val dTypes = Array(indptrND.dtype, indicesND.dtype)
+      val shapes = Array(indptrND.shape, indicesND.shape)
+      val handle =
+        newAllocHandle(fmt, shape, ctx, false, DType.Float32, dTypes, shapes)
+      // locator -1 fills the data blob; aux 0 is indptr and aux 1 is indices for CSR
+      checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, dataND.handle, -1))
+      checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indptrND.handle, 0))
+      checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indicesND.handle, 1))
+      new SparseNDArray(handle)
+    } finally {
+      // the staging arrays were copied synchronously and are no longer needed
+      dataND.dispose()
+      indicesND.dispose()
+      indptrND.dispose()
+    }
+  }
+
+  /**
+    * RowSparseNDArray stores the matrix in row sparse format,
+    * which is designed for arrays of which most row slices are all zeros
+    * @param data Any Array(Array(... Array(Float)))
+    * @param indices the indices to store the data
+    * @param shape shape of the NDArray
+    * @param ctx Context
+    * @return SparseNDArray
+    */
+  def rowSparseArray(data: Array[_], indices: Array[Float],
+                     shape: Shape, ctx: Context): SparseNDArray = {
+    val dataND = NDArray.toNDArray(data)
+    val indicesND = toInt64(indices, ctx)
+    try {
+      rowSparseArray(dataND, indicesND, shape, ctx)
+    } finally {
+      // both inputs were copied synchronously into the new sparse array
+      dataND.dispose()
+      indicesND.dispose()
+    }
+  }
+
+  /**
+    * RowSparseNDArray stores the matrix in row sparse format,
+    * which is designed for arrays of which most row slices are all zeros
+    * @param data NDArray input
+    * @param indices in NDArray. Only DType.Int64 supported
+    * @param shape shape of the NDArray
+    * @param ctx Context
+    * @return SparseNDArray
+    */
+  def rowSparseArray(data: NDArray, indices: NDArray,
+                     shape: Shape, ctx: Context): SparseNDArray = {
+    val fmt = SparseFormat.ROW_SPARSE
+    val handle = newAllocHandle(fmt, shape, ctx, false,
+      DType.Float32, Array(indices.dtype), Array(indices.shape))
+    // locator -1 fills the data blob; aux 0 holds the row indices
+    checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, data.handle, -1))
+    checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indices.handle, 0))
+    new SparseNDArray(handle)
+  }
+
+  /**
+    * Keep only the given rows of a row sparse array
+    * @param sparseNDArray the row sparse input
+    * @param indices the row indices to retain
+    * @return the retained rows as a SparseNDArray
+    * @throws IllegalArgumentException if the input is in CSR format
+    */
+  def retain(sparseNDArray: SparseNDArray, indices: Array[Float]): SparseNDArray = {
+    if (sparseNDArray.sparseFormat == SparseFormat.CSR) {
+      throw new IllegalArgumentException("CSR not supported")
+    }
+    NDArray.genericNDArrayFunctionInvoke("_sparse_retain",
+      Seq(sparseNDArray, NDArray.toNDArray(indices))).head.toSparse()
+  }
+
+  /** Stage float index values as an Int64 NDArray, freeing the float32 intermediate. */
+  private def toInt64(values: Array[Float], ctx: Context): NDArray = {
+    val asFloat = NDArray.array(values, Shape(values.length), ctx)
+    try asFloat.asType(DType.Int64) finally asFloat.dispose()
+  }
+
+  /**
+    * Allocate a native sparse array handle.
+    * Aux shapes are passed flattened (concatenated) alongside their ndims, so
+    * multi-dimensional aux arrays are described completely.
+    */
+  private def newAllocHandle(stype : SparseFormat,
+                             shape: Shape,
+                             ctx: Context,
+                             delayAlloc: Boolean,
+                             dtype: DType = DType.Float32,
+                             auxDTypes: Array[DType],
+                             auxShapes: Array[Shape]) : NDArrayHandle = {
+    val hdl = new NDArrayHandleRef
+    checkCall(_LIB.mxNDArrayCreateSparseEx(
+      stype.id,
+      shape.toArray,
+      shape.length,
+      ctx.deviceTypeid,
+      ctx.deviceId,
+      if (delayAlloc) 1 else 0,
+      dtype.id,
+      auxDTypes.length,
+      auxDTypes.map(_.id),
+      auxShapes.map(_.length),
+      auxShapes.flatMap(_.toArray),
+      hdl)
+    )
+    hdl.value
+  }
+}
+
+/**
+  * Sparse NDArray is the child class of NDArray designed to hold the Sparse format
+  *
+  * <p> Currently, RowSparse and CSR typed NDArray is supported. Most of the Operators
+  * will convert Sparse NDArray to dense. Basic operators like <code>add</code> will
+  * have optimization for sparse operations</p>
+  * @param handle The pointer that SparseNDArray holds
+  * @param writable whether the NDArray is writable
+  */
+class SparseNDArray private[mxnet] (override private[mxnet] val handle: NDArrayHandle,
+                                    override val writable: Boolean = true)
+  extends NDArray(handle, writable) {
+
+  /**
+    * Apply `f` to a freshly computed dense copy, then dispose the copy.
+    * A fresh copy (instead of a cached one) keeps results in sync with in-place
+    * updates of this array and does not keep a dense buffer alive.
+    */
+  private def withDense[T](f: NDArray => T): T = {
+    val dense = toDense
+    try f(dense) finally dense.dispose()
+  }
+
+  override def toString: String = withDense(_.toString)
+
+  /**
+    * Convert a SparseNDArray to dense NDArray
+    * @return NDArray
+    */
+  def toDense: NDArray = {
+    NDArray.api.cast_storage(this, SparseFormat.DEFAULT.toString).head
+  }
+
+  override def toArray: Array[Float] = withDense(_.toArray)
+
+  /** Row `idx` of the dense form; MXNet views share storage, so freeing the copy is safe. */
+  override def at(idx: Int): NDArray = withDense(_.at(idx))
+
+  override def slice(start: Int, end: Int): NDArray = {
+    NDArray.api.slice(this, Shape(start), Shape(end))
+  }
+
+  /**
+    * Get the Data portion from a Row Sparse NDArray
+    * @return NDArray
+    * @throws IllegalArgumentException if this array is not row sparse
+    */
+  def getData: NDArray = {
+    require(this.sparseFormat == SparseFormat.ROW_SPARSE, "Not Supported for CSR")
+    val dataHandle = new NDArrayHandleRef
+    checkCall(_LIB.mxNDArrayGetDataNDArray(this.handle, dataHandle))
+    new NDArray(dataHandle.value, false)
+  }
+
+  /**
+    * Get the indptr Array from a CSR NDArray
+    * @return NDArray
+    */
+  def getIndptr: NDArray = {
+    require(this.sparseFormat == SparseFormat.CSR, "Not Supported for row sparse")
+    getAuxNDArray(0)
+  }
+
+  /**
+    * Get the indices Array: the row indices (aux 0) of a ROW_SPARSE array,
+    * or the column indices (aux 1) of a CSR array
+    * @return NDArray
+    */
+  def getIndices: NDArray = {
+    if (this.sparseFormat == SparseFormat.ROW_SPARSE) {
+      getAuxNDArray(0)
+    } else {
+      getAuxNDArray(1)
+    }
+  }
+
+  /** Wrap auxiliary array `idx` of this sparse array as a read-only NDArray. */
+  private def getAuxNDArray(idx: Int): NDArray = {
+    val auxHandle = new NDArrayHandleRef
+    checkCall(_LIB.mxNDArrayGetAuxNDArray(this.handle, idx, auxHandle))
+    new NDArray(auxHandle.value, false)
+  }
+
+}
diff --git 
a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index c2ef641..82b9edc 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -45,6 +45,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll 
with Matchers {
     assert(ndones.toScalar === 1f)
   }
 
+  test("to sparse") {
+    val arr = Array(
+      Array(1f, 0f, 0f),
+      Array(0f, 3f, 0f),
+      Array(0f, 0f, 1f)
+    )
+    val nd = NDArray.toNDArray(arr)
+    assert(!nd.isSparse)
+    // row sparse
+    var ndSparse = nd.toSparse()  // no format given: ROW_SPARSE
+    assert(ndSparse.getIndices.toArray sameElements Array(0f, 1f, 2f))  // every row is non-zero
+    // csr
+    ndSparse = nd.toSparse(Some(SparseFormat.CSR))
+    assert(ndSparse.getIndptr.toArray sameElements Array(0f, 1f, 2f, 3f))  // one value per row
+  }
+
   test("to float 64 scalar") {
     val ndzeros = NDArray.zeros(Shape(1), dtype = DType.Float64)
     assert(ndzeros.toFloat64Scalar === 0d)
diff --git 
a/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala
new file mode 100644
index 0000000..f9968ef
--- /dev/null
+++ 
b/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.io.NDArrayIter
+import org.scalatest.FunSuite
+import org.slf4j.LoggerFactory
+
+class SparseNDArraySuite  extends FunSuite {
+
+  private val logger = LoggerFactory.getLogger(classOf[SparseNDArraySuite])
+
+  test("create CSR NDArray") {
+    val data = Array(7f, 8f, 9f)  // non-zero values in row-major order
+    val indices = Array(0f, 2f, 1f)  // column of each value
+    val indptr = Array(0f, 2f, 2f, 3f)  // row i holds values indptr(i) until indptr(i + 1)
+    val shape = Shape(3, 4)
+    val sparseND = SparseNDArray.csrMatrix(data, indices, indptr, shape, 
Context.cpu())
+    assert(sparseND.shape == Shape(3, 4))
+    assert(sparseND.toArray
+      sameElements Array(7.0f, 0.0f, 8.0f, 0.0f,
+                         0.0f, 0.0f, 0.0f, 0.0f,
+                         0.0f, 9.0f, 0.0f, 0.0f))
+    assert(sparseND.sparseFormat == SparseFormat.CSR)
+    assert(sparseND.getIndptr.toArray sameElements indptr)
+    assert(sparseND.getIndices.toArray sameElements indices)
+  }
+
+  test("create Row Sparse NDArray") {
+    val data = Array(
+      Array(1f, 2f),
+      Array(3f, 4f)
+    )
+    val indices = Array(1f, 4f)  // the two data rows land in rows 1 and 4
+    val shape = Shape(6, 2)
+    val sparseND = SparseNDArray.rowSparseArray(data, indices, shape, 
Context.cpu())
+    assert(sparseND.sparseFormat == SparseFormat.ROW_SPARSE)
+    assert(sparseND.shape == Shape(6, 2))
+    assert(sparseND.at(1).toArray sameElements Array(1f, 2f))
+    assert(sparseND.getIndices.toArray sameElements indices)
+  }
+
+  test("Test retain") {
+    val arr = Array(
+      Array(1f, 2f),
+      Array(3f, 4f),
+      Array(5f, 6f)
+    )
+    val indices = Array(0f, 1f, 3f)  // rows 0, 1 and 3 of the 4x2 array hold data
+    val rspIn = SparseNDArray.rowSparseArray(arr, indices, Shape(4, 2), 
Context.cpu())
+    val toRetain = Array(0f, 3f)  // keep rows 0 and 3; row 1 is dropped
+    val rspOut = SparseNDArray.retain(rspIn, toRetain)
+    assert(rspOut.getData.toArray sameElements Array(1f, 2f, 5f, 6f))
+    assert(rspOut.getIndices.toArray sameElements Array(0f, 3f))
+  }
+
+  test("Test add") {
+    val nd = NDArray.array(Array(1f, 2f, 3f), 
Shape(3)).toSparse(Some(SparseFormat.ROW_SPARSE))
+    val nd2 = nd + nd  // sparse + sparse should stay sparse
+    assert(nd2.isInstanceOf[SparseNDArray])
+    assert(nd2.toArray sameElements Array(2f, 4f, 6f))
+  }
+
+  test("Test DataIter") {
+    val nd = NDArray.array(Array(1f, 2f, 3f), Shape(1, 
3)).toSparse(Some(SparseFormat.CSR))
+    val arr = IndexedSeq(nd, nd, nd, nd)
+    val iter = new NDArrayIter(arr)  // batches must keep the CSR storage format
+    while (iter.hasNext) {
+      val tempArr = iter.next().data
+      tempArr.foreach(ele => {
+        assert(ele.sparseFormat == SparseFormat.CSR)
+        assert(ele.shape == Shape(1, 3))
+      })
+    }
+  }
+
+
+}
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index 9b19fd3..387a0b1 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -93,6 +93,31 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx
+  (JNIEnv *env, jobject obj, jint storageType, jintArray shape, jint ndim, jint devType,
+    jint devId, jint delayAlloc, jint dtype, jint numAux, jintArray auxTypes,
+    jintArray auxNdims, jintArray auxShapes, jobject ndArrayHandle) {
+  jint *shapeArr = env->GetIntArrayElements(shape, NULL);
+  jint *auxTypesArr = env->GetIntArrayElements(auxTypes, NULL);
+  jint *auxNdimsArr = env->GetIntArrayElements(auxNdims, NULL);
+  jint *auxShapesArr = env->GetIntArrayElements(auxShapes, NULL);
+  NDArrayHandle out = NULL;  // never hand an uninitialized handle back to the JVM
+  int ret = MXNDArrayCreateSparseEx(storageType,
+    reinterpret_cast<const mx_uint *>(shapeArr),
+    static_cast<mx_uint>(ndim),
+    devType, devId, delayAlloc, dtype,
+    static_cast<mx_uint>(numAux),
+    reinterpret_cast<int *>(auxTypesArr),
+    reinterpret_cast<mx_uint *>(auxNdimsArr),
+    reinterpret_cast<const mx_uint *>(auxShapesArr), &out);
+  env->ReleaseIntArrayElements(shape, shapeArr, JNI_ABORT);  // read-only: skip copy-back
+  env->ReleaseIntArrayElements(auxTypes, auxTypesArr, JNI_ABORT);
+  env->ReleaseIntArrayElements(auxNdims, auxNdimsArr, JNI_ABORT);
+  env->ReleaseIntArrayElements(auxShapes, auxShapesArr, JNI_ABORT);
+  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv *env, jobject obj) {
   return MXNDArrayWaitAll();
 }
@@ -179,10 +204,10 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFuncGetInfo
   return ret;
 }
 
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx
   (JNIEnv *env, jobject obj, jlong funcPtr, jlongArray inputs,
     jlongArray outputsGiven, jobject outputs, jint numParams,
-    jobjectArray paramKeys, jobjectArray paramVals) {
+    jobjectArray paramKeys, jobjectArray paramVals, jobject outStypes) {
 
   const char **cParamKeys = NULL;
   const char **cParamVals = NULL;
@@ -204,6 +229,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
   int numOutputs = 0;
   jlong *cOutputsGiven = NULL;
   NDArrayHandle *cOutputs = NULL;
+  const int *cOutStypes = NULL;  // may be left unset if MXImperativeInvokeEx fails
   if (outputsGiven) {
     cOutputsGiven = env->GetLongArrayElements(outputsGiven, NULL);
     cOutputs = reinterpret_cast<NDArrayHandle *>(cOutputsGiven);
@@ -211,14 +237,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
   }
   jlong *cInputs = env->GetLongArrayElements(inputs, NULL);
   jsize numInputs = env->GetArrayLength(inputs);
-  int ret = MXImperativeInvoke(reinterpret_cast<AtomicSymbolCreator>(funcPtr),
+  int ret = MXImperativeInvokeEx(reinterpret_cast<AtomicSymbolCreator>(funcPtr),
                                 static_cast<int>(numInputs),
                                 reinterpret_cast<NDArrayHandle *>(cInputs),
                                 &numOutputs,
                                 &cOutputs,
                                 static_cast<int>(numParams),
                                 cParamKeys,
-                               cParamVals);
+                               cParamVals,
+                               &cOutStypes);
   env->ReleaseLongArrayElements(inputs, cInputs, 0);
   if (cOutputsGiven) {
     env->ReleaseLongArrayElements(outputsGiven, cOutputsGiven, 0);
@@ -240,7 +267,9 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
 
-  if (cOutputs) {
+  if (cOutputs && cOutStypes) {  // cOutStypes stays NULL unless the call filled it in
     jclass longCls = env->FindClass("java/lang/Long");
+    jclass intCls = env->FindClass("java/lang/Integer");
     jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
+    jmethodID intConst = env->GetMethodID(intCls, "<init>", "(I)V");
     // scala.collection.mutable.ListBuffer append method
     jclass listClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
     jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
@@ -249,6 +278,9 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
       env->CallObjectMethod(outputs, listAppend,
                             env->NewObject(longCls, longConst,
                             reinterpret_cast<uint64_t>(cOutputs[i])));
+      env->CallObjectMethod(outStypes, listAppend,
+                            env->NewObject(intCls, intConst,
+                            cOutStypes[i]));
     }
   }
 
@@ -379,6 +411,14 @@ JNIEXPORT jint JNICALL 
Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
   return ret;
 }
 
+JNIEXPORT jint JNICALL 
Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray
+  (JNIEnv *env, jobject obj, jlong dstPtr, jlong srcPtr, jint locator) {
+  int ret = 
MXNDArraySyncCopyFromNDArray(reinterpret_cast<NDArrayHandle>(dstPtr),
+                                   reinterpret_cast<NDArrayHandle>(srcPtr),
+                                   locator);
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyToCPU
   (JNIEnv *env, jobject obj, jlong ndArrayPtr, jbyteArray data, jint size) {
   jbyte *pdata = env->GetByteArrayElements(data, NULL);
@@ -434,6 +474,37 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFro
   return ret;
 }
 
+// JNI binding for LibInfo.mxNDArrayGetDataNDArray: calls
+// MXNDArrayGetDataNDArray on the NDArray behind `arrayPtr` (presumably to get
+// a handle to its data blob) and stores the resulting handle in the RefLong
+// `ndArrayHandle`. Returns the C API status code.
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray
+  (JNIEnv *env, jobject obj, jlong arrayPtr, jobject ndArrayHandle) {
+  // Initialized so that a failed call writes back 0 rather than an
+  // indeterminate pointer; callers must still check the returned status.
+  NDArrayHandle out = nullptr;
+  int ret = MXNDArrayGetDataNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
+                                    &out);
+  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
+  return ret;
+}
+
+// JNI binding for LibInfo.mxNDArrayGetAuxNDArray: same contract as
+// mxNDArrayGetDataNDArray above, but requests the auxiliary array at index
+// `location`.
+// NOTE(review): `location` is cast to mx_uint, so a negative value wraps to a
+// huge index; callers are expected to pass a non-negative index.
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray
+  (JNIEnv *env, jobject obj, jlong arrayPtr, jint location, jobject ndArrayHandle) {
+  // Initialized for the same reason as in mxNDArrayGetDataNDArray.
+  NDArrayHandle out = nullptr;
+  int ret = MXNDArrayGetAuxNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
+                                   static_cast<mx_uint>(location),
+                                   &out);
+  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext
   (JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId) {
   int outDevType;
@@ -540,6 +599,20 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType
   return ret;
 }
 
+// JNI binding for LibInfo.mxNDArrayGetStorageType: stores the storage-type
+// code reported by MXNDArrayGetStorageType for the NDArray behind `jhandle`
+// into the RefInt `jstype`. Returns the C API status code.
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType
+  (JNIEnv *env, jobject obj, jlong jhandle, jobject jstype) {
+  // -1 is only a placeholder so that a failed call never writes back an
+  // indeterminate value; callers must still check the returned status.
+  int stype = -1;
+  int ret = MXNDArrayGetStorageType(reinterpret_cast<NDArrayHandle>(jhandle),
+                                    &stype);
+  SetIntField(env, jstype, stype);
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxInitPSEnv
   (JNIEnv *env, jobject obj, jobjectArray jkeys, jobjectArray jvals) {
   // keys and values
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index fac32bb..c8ee0ce 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -41,11 +41,11 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_nnGetOpHandle
 
 /*
  * Class:     org_apache_mxnet_LibInfo
- * Method:    mxImperativeInvoke
- * Signature: (J[J[JLscala/collection/mutable/ArrayBuffer;I[Ljava/lang/String;[Ljava/lang/String;)I
+ * Method:    mxImperativeInvokeEx
+ * Signature: (J[J[JLscala/collection/mutable/ArrayBuffer;I[Ljava/lang/String;[Ljava/lang/String;Lscala/collection/mutable/ArrayBuffer;)I
  */
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
-  (JNIEnv *, jobject, jlong, jlongArray, jlongArray, jobject, jint, jobjectArray, jobjectArray);
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx
+  (JNIEnv *, jobject, jlong, jlongArray, jlongArray, jobject, jint, jobjectArray, jobjectArray, jobject);
 
 /*
  * Class:     org_apache_mxnet_LibInfo
@@ -73,6 +73,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx
 
 /*
  * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxNDArrayCreateSparseEx
+ * Signature: (I[IIIIIII[I[I[ILorg/apache/mxnet/Base/RefLong;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx
+  (JNIEnv *, jobject, jint, jintArray, jint, jint, jint, jint, jint, jint, jintArray, jintArray, jintArray, jobject);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
  * Method:    mxNDArrayWaitAll
  * Signature: ()I
  */
@@ -137,6 +145,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
 
 /*
  * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxNDArraySyncCopyFromNDArray
+ * Signature: (JJI)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray
+  (JNIEnv *, jobject, jlong, jlong, jint);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
  * Method:    mxNDArraySyncCopyToCPU
  * Signature: (J[BI)I
  */
@@ -201,6 +217,22 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySave
 
 /*
  * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxNDArrayGetDataNDArray
+ * Signature: (JLorg/apache/mxnet/Base/RefLong;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray
+  (JNIEnv *, jobject, jlong, jobject);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxNDArrayGetAuxNDArray
+ * Signature: (JILorg/apache/mxnet/Base/RefLong;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray
+  (JNIEnv *, jobject, jlong, jint, jobject);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
  * Method:    mxNDArrayGetContext
  * Signature: (JLorg/apache/mxnet/Base/RefInt;Lorg/apache/mxnet/Base/RefInt;)I
  */
@@ -233,6 +265,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType
 
 /*
  * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxNDArrayGetStorageType
+ * Signature: (JLorg/apache/mxnet/Base/RefInt;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType
+  (JNIEnv *, jobject, jlong, jobject);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
  * Method:    mxInitPSEnv
  * Signature: ([Ljava/lang/String;[Ljava/lang/String;)I
  */

Reply via email to