This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch v1.4.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.4.x by this push:
     new 69515c2  [v1.4.1] Java bug-fix cherry pick (#14834)
69515c2 is described below

commit 69515c2f9b1ac6fd4b661d5411a97de968cf4e2e
Author: Lanking <lanking...@live.com>
AuthorDate: Mon Apr 29 16:03:36 2019 -0700

    [v1.4.1] Java bug-fix cherry pick (#14834)
    
    * clean up submodule (#14645)
    
    * Scala/Java Predict API fix #14756 (#14804)
    
    * add fix in the code
    
    * add unit test
    
    * update comments
    
    * add fixes to code gen
---
 .../scala/org/apache/mxnet/module/BaseModule.scala |  17 +-
 .../java/org/apache/mxnet/javaapi/NDArrayTest.java |   4 +-
 .../test/scala/org/apache/mxnet/ModuleSuite.scala  |  28 ++++
 .../scala/org/apache/mxnet/APIDocGenerator.scala   | 184 ++++++++++++++-------
 4 files changed, 173 insertions(+), 60 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index b73f4ad..73ccef2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -247,11 +247,23 @@ abstract class BaseModule {
 
   /**
    * Run prediction and collect the outputs.
-   * @param evalData
+   * @param evalData dataIter to do the Inference
    * @param numBatch Default is -1, indicating running all the batches in the data iterator.
    * @param reset Default is `True`, indicating whether we should reset the data iter before start
    *              doing prediction.
    * @return The return value will be a list `[out1, out2, out3]`.
+   *        The concatenation process will be like
+   *        {{{
+   *            outputBatches = [
+   *              [a1, a2, a3], // batch a
+   *              [b1, b2, b3]  // batch b
+   *            ]
+   *            result = [
+   *              NDArray, // [a1, b1]
+   *              NDArray, // [a2, b2]
+   *              NDArray, // [a3, b3]
+   *            ]
+   *        }}}
    *         Where each element is concatenation of the outputs for all the mini-batches.
    */
   def predict(evalData: DataIter, numBatch: Int = -1, reset: Boolean = true)
@@ -264,7 +276,8 @@ abstract class BaseModule {
           s"in mini-batches (${out.size})." +
       "Maybe bucketing is used?")
     )
-    val concatenatedOutput = outputBatches.map(out => NDArray.concatenate(out))
+    val oBT = outputBatches.transpose
+    val concatenatedOutput = oBT.map(out => NDArray.concatenate(out))
     outputBatches.foreach(_.foreach(_.dispose()))
     concatenatedOutput
   }
diff --git a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
index 2659b78..5bbe8bb 100644
--- a/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
+++ b/scala-package/core/src/test/java/org/apache/mxnet/javaapi/NDArrayTest.java
@@ -71,7 +71,7 @@ public class NDArrayTest {
         NDArray$ NDArray = NDArray$.MODULE$;
         float[] arr = new float[]{1.0f, 2.0f, 3.0f};
         NDArray nd = new NDArray(arr, new Shape(new int[]{3}), new Context("cpu", 0));
-        float result = NDArray.norm(NDArray.new normParam(nd))[0].toArray()[0];
+        float result = NDArray.norm(new normParam(nd))[0].toArray()[0];
         float cal = 0.0f;
         for (float ele : arr) {
             cal += ele * ele;
@@ -79,7 +79,7 @@ public class NDArrayTest {
         cal = (float) Math.sqrt(cal);
         assertTrue(Math.abs(result - cal) < 1e-5);
         NDArray dotResult = new NDArray(new float[]{0}, new Shape(new int[]{1}), new Context("cpu", 0));
-        NDArray.dot(NDArray.new dotParam(nd, nd).setOut(dotResult));
+        NDArray.dot(new dotParam(nd, nd).setOut(dotResult));
         assertTrue(Arrays.equals(dotResult.toArray(), new float[]{14.0f}));
     }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 88e314e..e6ebfd3 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -23,6 +23,34 @@ import org.apache.mxnet.optimizer._
 import org.apache.mxnet.io._
 
 class ModuleSuite extends FunSuite with BeforeAndAfterAll {
+
+  class myModule(symbol : Symbol) extends Module (symbol) {
+    override def predictEveryBatch(evalData: DataIter,
+                                   numBatch: Int = 1, reset: Boolean = true):
+    IndexedSeq[IndexedSeq[NDArray]] = {
+      val data = IndexedSeq(
+        NDArray.ones(Shape(1, 10, 1)),
+        NDArray.ones(Shape(1, 10, 1)),
+        NDArray.ones(Shape(1, 10, 4))
+      )
+      List.fill(numBatch)(data).toIndexedSeq
+    }
+  }
+
+  test("predict") {
+    val sym = Symbol.Variable("data")
+    val mod = new myModule(sym)
+    val dummyIter = new NDArrayIter(IndexedSeq(NDArray.ones(1)))
+    var output = mod.predict(dummyIter, 1)
+    require(output(0).shape == Shape(1, 10, 1))
+    require(output(1).shape == Shape(1, 10, 1))
+    require(output(2).shape == Shape(1, 10, 4))
+    output = mod.predict(dummyIter, 2)
+    require(output(0).shape == Shape(2, 10, 1))
+    require(output(1).shape == Shape(2, 10, 1))
+    require(output(2).shape == Shape(2, 10, 4))
+  }
+
   test ("model dtype") {
     val dType = DType.Float32
     val dShape = Shape(3, 8, 7)
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
index ce12dc7..77a2704 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
@@ -23,12 +23,16 @@ import java.security.MessageDigest
 import scala.collection.mutable.ListBuffer
 
 /**
-  * This object will generate the Scala documentation of the new Scala API
-  * Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
+  * This object will generate the Scala documentation of the Scala/Java APIs
   * The code will be executed during Macros stage and file live in Core stage
   */
 private[mxnet] object APIDocGenerator extends GeneratorBase {
 
+  /**
+    * Main method used to generate code and write to files
+    * A hash check placed at the end to verify changes
+    * @param args Input args
+    */
   def main(args: Array[String]): Unit = {
     val FILE_PATH = args(0)
     val hashCollector = ListBuffer[String]()
@@ -40,6 +44,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
     val finalHash = hashCollector.mkString("\n")
   }
 
+  /**
+    * Generate MD5 result from an input string
+    * Encoded in UTF-8
+    * @param input The input string
+    * @return A MD5 value from the string
+    */
   def MD5Generator(input: String): String = {
     val md = MessageDigest.getInstance("MD5")
     md.update(input.getBytes("UTF-8"))
@@ -47,6 +57,12 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
     org.apache.commons.codec.binary.Base64.encodeBase64URLSafeString(digest)
   }
 
+  /**
+    * Type-safe class body generation for NDArray/Symbol
+    * @param FILE_PATH File path write the file to
+    * @param isSymbol Check if write the Symbol API, NDArray otherwise
+    * @return MD5 String
+    */
   def typeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
     val generated = typeSafeFunctionsToGenerate(isSymbol, isContrib = false)
       .map { func =>
@@ -57,11 +73,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
 
     writeFile(
       FILE_PATH,
-      if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
       "package org.apache.mxnet",
+      if (isSymbol) "SymbolAPIBase" else "NDArrayAPIBase",
+      "import org.apache.mxnet.annotation.Experimental",
       generated)
   }
 
+  /**
+    * Non Type-safe interface of Scala Symbol/NDArray
+    * It includes class definition : e.g class SymbolBase
+    * and function definitions : e.g def softmax(...)(...)(...) : NDArray
+    * Users can directly use the api by calling NDArray.<function_name>
+    * It support both positional input or Map input
+    * @param FILE_PATH File path write the file to
+    * @param isSymbol Check if write the Symbol API, NDArray otherwise
+    * @return MD5 String
+    */
   def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
     val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
       .map { func =>
@@ -85,34 +112,53 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
 
     writeFile(
       FILE_PATH,
-      if (isSymbol) "SymbolBase" else "NDArrayBase",
       "package org.apache.mxnet",
+      if (isSymbol) "SymbolBase" else "NDArrayBase",
+      "import org.apache.mxnet.annotation.Experimental",
       absFuncs)
   }
 
-  def javaClassGen(filePath : String) : String = {
+  /**
+    * Type-safe interface of Java NDArray
+    * @param FILE_PATH File path write the file to
+    * @return MD5 String
+    */
+  def javaClassGen(FILE_PATH : String) : String = {
     val notGenerated = Set("Custom")
     val absClassFunctions = functionsToGenerate(false, false, true)
-    val absFuncs = absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
-      .groupBy(_.name.toLowerCase).map(ele => {
-      /* Pattern matching for not generating deprecated method
-       * Group all method name in lowercase
-       * Kill the capital lettered method such as Cast vs cast
-       * As it defined by default it deprecated
-       */
-      if (ele._2.length == 1) ele._2.head
-      else {
-        if (ele._2.head.name.head.isLower) ele._2.head
-        else ele._2.last
-      }
-    }).map(absClassFunction => {
+    val (absFuncs, paramClassUncleaned) =
+      absClassFunctions.filterNot(ele => notGenerated.contains(ele.name))
+        .groupBy(_.name.toLowerCase).map(ele => {
+        /* Pattern matching for not generating deprecated method
+         * Group all method name in lowercase
+         * Kill the capital lettered method such as Cast vs cast
+         * As it defined by default it deprecated
+         */
+        if (ele._2.length == 1) ele._2.head
+        else {
+          if (ele._2.head.name.head.isLower) ele._2.head
+          else ele._2.last
+        }
+      }).map(absClassFunction => {
         generateJavaAPISignature(absClassFunction)
-      }).toSeq
+      }).toSeq.unzip
+    val paramClass = paramClassUncleaned.filterNot(_.isEmpty)
     val packageName = "NDArrayBase"
     val packageDef = "package org.apache.mxnet.javaapi"
-    writeFile(filePath + "javaapi/", packageName, packageDef, absFuncs)
+    writeFile(
+      FILE_PATH + "javaapi/",
+      packageDef,
+      packageName,
+      "import org.apache.mxnet.annotation.Experimental",
+      absFuncs, Some(paramClass))
   }
 
+  /**
+    * Generate Scala docs from the function description
+    * @param func The function case class
+    * @param withParam Whether to generate param field
+    * @return A formatted string for the function description
+    */
   def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
     def fixDesc(desc: String): String = {
       var curDesc = desc
@@ -146,7 +192,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
     }
   }
 
-  def generateAPISignature(func: Func, isSymbol: Boolean): String = {
+  /**
+    * Generate the function interface
+    * e.g: def softmax(data: NDArray, name ...): NDArrayFunctionReturn
+    * @param func The function case class
+    * @param isSymbol Check if generate Symbol function, NDArray otherwise
+    * @param typeParameter Type param specifically used in Random Module
+    * @return Formatted string for the function
+    */
+  def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = {
     val argDef = ListBuffer[String]()
 
     argDef ++= typedFunctionCommonArgDef(func)
@@ -162,10 +216,15 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
     val returnType = func.returnType
 
     s"""@Experimental
-       |def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
+       |def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin
   }
 
-  def generateJavaAPISignature(func : Func) : String = {
+  /**
+    * Generate Java function interface
+    * @param func The function case class
+    * @return A formatted string for the function
+    */
+  def generateJavaAPISignature(func : Func) : (String, String) = {
     val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
     var argDef = ListBuffer[String]()
     var classDef = ListBuffer[String]()
@@ -204,54 +263,67 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
            | }
            | def getOut() = this.out
            | """.stripMargin
-      s"""$scalaDocNoParam
-         | $experimentalTag
-         | def ${func.name}(po: ${func.name}Param) : $returnType
-         | /**
-         | * This Param Object is specifically used for ${func.name}
-         | ${requiredParam.mkString("\n")}
-         | */
-         | class ${func.name}Param(${argDef.mkString(",")}) {
-         |  ${classDef.mkString("\n  ")}
-         | }""".stripMargin
+      (s"""$scalaDocNoParam
+          | $experimentalTag
+          | def ${func.name}(po: ${func.name}Param) : $returnType
+          | """.stripMargin,
+        s"""/**
+           | * This Param Object is specifically used for ${func.name}
+           | ${requiredParam.mkString("\n")}
+           | */
+           | class ${func.name}Param(${argDef.mkString(",")}) {
+           |  ${classDef.mkString("\n  ")}
+           | }""".stripMargin)
     } else {
       argDef += "out : NDArray"
-      s"""$scalaDoc
-         |$experimentalTag
-         | def ${func.name}(${argDef.mkString(", ")}) : $returnType
-         | """.stripMargin
+      (s"""$scalaDoc
+          |$experimentalTag
+          | def ${func.name}(${argDef.mkString(", ")}) : $returnType
+          | """.stripMargin, "")
     }
   }
 
-  def writeFile(FILE_PATH: String, className: String, packageDef: String,
-                absFuncs: Seq[String]): String = {
+  /**
+    * Write the formatted string to file
+    * @param FILE_PATH Location of the file writes to
+    * @param packageDef Package definition
+    * @param className Class name
+    * @param imports Packages need to import
+    * @param absFuncs All formatted functions
+    * @return A MD5 string
+    */
+  def writeFile(FILE_PATH: String, packageDef: String, className: String,
+                imports: String, absFuncs: Seq[String],
+                paramClass: Option[Seq[String]] = None): String = {
 
     val finalStr =
       s"""/*
-         |* Licensed to the Apache Software Foundation (ASF) under one or more
-         |* contributor license agreements.  See the NOTICE file distributed with
-         |* this work for additional information regarding copyright ownership.
-         |* The ASF licenses this file to You under the Apache License, Version 2.0
-         |* (the "License"); you may not use this file except in compliance with
-         |* the License.  You may obtain a copy of the License at
-         |*
-         |*    http://www.apache.org/licenses/LICENSE-2.0
-         |*
-         |* Unless required by applicable law or agreed to in writing, software
-         |* distributed under the License is distributed on an "AS IS" BASIS,
-         |* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-         |* See the License for the specific language governing permissions and
-         |* limitations under the License.
-         |*/
+         | * Licensed to the Apache Software Foundation (ASF) under one or more
+         | * contributor license agreements.  See the NOTICE file distributed with
+         | * this work for additional information regarding copyright ownership.
+         | * The ASF licenses this file to You under the Apache License, Version 2.0
+         | * (the "License"); you may not use this file except in compliance with
+         | * the License.  You may obtain a copy of the License at
+         | *
+         | *    http://www.apache.org/licenses/LICENSE-2.0
+         | *
+         | * Unless required by applicable law or agreed to in writing, software
+         | * distributed under the License is distributed on an "AS IS" BASIS,
+         | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+         | * See the License for the specific language governing permissions and
+         | * limitations under the License.
+         | */
          |
          |$packageDef
          |
-         |import org.apache.mxnet.annotation.Experimental
+         |$imports
          |
          |// scalastyle:off
          |abstract class $className {
          |${absFuncs.mkString("\n")}
-         |}""".stripMargin
+         |}
+         |${paramClass.getOrElse(Seq()).mkString("\n")}
+         |""".stripMargin
 
 
     val pw = new PrintWriter(new File(FILE_PATH + s"$className.scala"))

Reply via email to