This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b482a44  [MXNET-1379] update reshape operator (#14600)
b482a44 is described below

commit b482a44fa8cd932ed48d62bafadb11299c7cd381
Author: Lanking <lanking...@live.com>
AuthorDate: Wed Apr 3 10:07:38 2019 -0700

    [MXNET-1379] update reshape operator (#14600)
    
    * update reshape operator
    
    * Satisfy the Lint God =v=
    
    * update the jni header signature
---
 .../core/src/main/scala/org/apache/mxnet/LibInfo.scala      |  5 +++--
 .../core/src/main/scala/org/apache/mxnet/NDArray.scala      | 13 ++++++++++++-
 .../core/src/test/scala/org/apache/mxnet/NDArraySuite.scala |  8 ++++++--
 .../native/src/main/native/org_apache_mxnet_native_c_api.cc | 13 +++++++------
 .../native/src/main/native/org_apache_mxnet_native_c_api.h  |  8 ++++----
 5 files changed, 32 insertions(+), 15 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 20b6ed9..40fc095 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -86,9 +86,10 @@ private[mxnet] class LibInfo {
   @native def mxNDArrayAt(handle: NDArrayHandle,
                           idx: MXUint,
                           out: NDArrayHandleRef): Int
-  @native def mxNDArrayReshape(handle: NDArrayHandle,
-                               nDim: Int,
-                               dims: Array[Int],
-                               reshapeHandle: NDArrayHandleRef): Int
+  @native def mxNDArrayReshape64(handle: NDArrayHandle,
+                                 nDim: Int,
+                                 dims: Array[Long],
+                                 reverse: Boolean,
+                                 reshapeHandle: NDArrayHandleRef): Int
   @native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
                                        source: Array[MXFloat],
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 915e4c6..ab42265 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -950,8 +950,21 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
    * @return a reshaped NDArray that shares memory with current one.
    */
   def reshape(dims: Array[Int]): NDArray = {
+    reshape(dims.map(_.toLong))
+  }
+
+  /**
+    * Return a reshaped NDArray that shares memory with current one.
+    * @param dims New shape; may contain the special values understood by the Reshape
+    *             operator (e.g. 0 copies an input dim, -3 merges two consecutive dims).
+    * @param reverse if Some(true), special values in dims are inferred from right to
+    *                left instead of left to right. None (the default) means false.
+    * @return a reshaped NDArray that shares memory with current one.
+    */
+  def reshape(dims: Array[Long], reverse: Option[Boolean] = None): NDArray = {
     val reshapeHandle = new NDArrayHandleRef
-    checkCall(_LIB.mxNDArrayReshape(handle, dims.length, dims, reshapeHandle))
+    checkCall(_LIB.mxNDArrayReshape64(handle,
+      dims.length, dims, reverse.getOrElse(false), reshapeHandle))
     new NDArray(handle = reshapeHandle.value, writable = this.writable)
   }
 
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 206094c..c2ef641 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -878,14 +878,23 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
   }
 
   test("reshape") {
     val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2))
 
     val arr1 = arr.reshape(Array(2, 3))
     assert(arr1.shape === Shape(2, 3))
     assert(arr1.toArray === Array(1f, 2f, 3f, 4f, 5f, 6f))
 
     arr.set(1f)
     assert(arr1.toArray === Array(1f, 1f, 1f, 1f, 1f, 1f))
+
+    // 0 copies the input dim; -3 merges the next two input dims into one
+    val arr2 = NDArray.ones(1, 384, 1)
+    assert(arr2.reshape(Array(0, -3)).shape === Shape(1, 384))
+
+    // reverse infers the special values from right to left instead of left to right
+    val arr3 = NDArray.ones(10, 5, 4)
+    assert(arr3.reshape(Array(-1L, 0L)).shape === Shape(40, 5))
+    assert(arr3.reshape(Array(-1L, 0L), Some(true)).shape === Shape(50, 4))
   }
 
   test("dispose deps") {
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index ea6e9c8..33e4cca 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -404,14 +404,16 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
   return ret;
 }
 
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape
-  (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, jintArray dims, jobject reshapedHandle) {
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
+  (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim,
+   jlongArray dims, jboolean reverse, jobject reshapedHandle) {
   NDArrayHandle out;
-  jint *pdims = env->GetIntArrayElements(dims, NULL);
-  int ret = MXNDArrayReshape(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
-                                    reinterpret_cast<int *>(pdims), &out);
+  jlong *pdims = env->GetLongArrayElements(dims, NULL);
+  int ret = MXNDArrayReshape64(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
+                               reinterpret_cast<dim_t *>(pdims), reverse, &out);
   SetLongField(env, reshapedHandle, reinterpret_cast<jlong>(out));
-  env->ReleaseIntArrayElements(dims, pdims, 0);
+  // dims is input-only: JNI_ABORT frees the buffer without copying it back.
+  env->ReleaseLongArrayElements(dims, pdims, JNI_ABORT);
   return ret;
 }
 
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index 7e8e03d..b8a9b3b 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -161,11 +161,11 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
 
 /*
  * Class:     org_apache_mxnet_LibInfo
- * Method:    mxNDArrayReshape
- * Signature: (JI[ILorg/apache/mxnet/Base/RefLong;)I
+ * Method:    mxNDArrayReshape64
+ * Signature: (JI[JZLorg/apache/mxnet/Base/RefLong;)I
  */
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape
-  (JNIEnv *, jobject, jlong, jint, jintArray, jobject);
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
+  (JNIEnv *, jobject, jlong, jint, jlongArray, jboolean, jobject);
 
 /*
  * Class:     org_apache_mxnet_LibInfo

Reply via email to