This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 6cbc273  Scala/Java Predict API fix #14756 (#14804)
6cbc273 is described below

commit 6cbc273454d0e6c24307183eeedc5c3e33e2ccc1
Author: Lanking <lanking...@live.com>
AuthorDate: Fri Apr 26 20:40:02 2019 -0700

    Scala/Java Predict API fix #14756 (#14804)
    
    * add fix in the code
    
    * add unit test
    
    * update comments
---
 .../scala/org/apache/mxnet/module/BaseModule.scala | 17 +++++++++++--
 .../test/scala/org/apache/mxnet/ModuleSuite.scala  | 28 ++++++++++++++++++++++
 2 files changed, 43 insertions(+), 2 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 3be8e06..7fbdae5 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -247,11 +247,23 @@ abstract class BaseModule {
 
   /**
    * Run prediction and collect the outputs.
-   * @param evalData
+   * @param evalData dataIter to do the Inference
    * @param numBatch Default is -1, indicating running all the batches in the data iterator.
    * @param reset Default is `True`, indicating whether we should reset the data iter before start
    *              doing prediction.
    * @return The return value will be a list `[out1, out2, out3]`.
+   *        The concatenation process will be like
+   *        {{{
+   *            outputBatches = [
+   *              [a1, a2, a3], // batch a
+   *              [b1, b2, b3]  // batch b
+   *            ]
+   *            result = [
+   *              NDArray, // [a1, b1]
+   *              NDArray, // [a2, b2]
+   *              NDArray, // [a3, b3]
+   *            ]
+   *        }}}
    *         Where each element is concatenation of the outputs for all the mini-batches.
    */
   def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
@@ -264,7 +276,8 @@ abstract class BaseModule {
           s"in mini-batches (${out.size})." +
       "Maybe bucketing is used?")
     )
-    val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
+    val oBT = outputBatches.transpose
+    val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
     outputBatches.foreach(_.foreach(_.dispose()))
     concatenatedOutput
   }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 3e753a1..5aed01b 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
 import org.apache.mxnet.io._
 
 class ModuleSuite extends FunSuite with BeforeAndAfterAll {
+
+  class myModule(symbol : Symbol) extends Module (symbol) {
+    override def predictEveryBatch(evalData: DataIter,
+                                   numBatch: Int = 1, reset: Boolean = true):
+    IndexedSeq[IndexedSeq[NDArray]] = {
+      val data = IndexedSeq(
+        NDArray.ones(Shape(1, 10, 1)),
+        NDArray.ones(Shape(1, 10, 1)),
+        NDArray.ones(Shape(1, 10, 4))
+      )
+      List.fill(numBatch)(data).toIndexedSeq
+    }
+  }
+
+  test("predict") {
+    val sym = Symbol.Variable("data")
+    val mod = new myModule(sym)
+    val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
+    var output = mod.predict(dummyIter, 1)
+    require(output(0).shape == Shape(1, 10, 1))
+    require(output(1).shape == Shape(1, 10, 1))
+    require(output(2).shape == Shape(1, 10, 4))
+    output = mod.predict(dummyIter, 2)
+    require(output(0).shape == Shape(2, 10, 1))
+    require(output(1).shape == Shape(2, 10, 1))
+    require(output(2).shape == Shape(2, 10, 4))
+  }
+
   test ("model dtype") {
     val dType = DType.Float32
     val dShape = Shape(3, 8, 7)

Reply via email to