This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0038473  [MXNET-1222] Scala Inference enable different shapes input 
(#13330)
0038473 is described below

commit 0038473e799bccd77f57718eb5f8af28b81c8284
Author: Lanking <lanking...@live.com>
AuthorDate: Thu Nov 29 16:16:45 2018 -0800

    [MXNET-1222] Scala Inference enable different shapes input (#13330)
    
    * init commit with Predictor Improvement
    
    * add predictor Example
    
    * change into dArr
    
    * add img config
    
    * add new line and fix code style
    
    important bug fixes
---
 .../src/main/scala/org/apache/mxnet/Executor.scala |  4 +-
 .../infer/predictor/PredictorExample.scala         | 92 ++++++++++++++++++++++
 .../ImageClassifierExampleSuite.scala              |  5 +-
 .../ObjectDetectorExampleSuite.scala               |  5 +-
 .../PredictorExampleSuite.scala}                   | 67 +++++++++-------
 .../scala/org/apache/mxnet/infer/Predictor.scala   | 31 ++++++--
 .../org/apache/mxnet/infer/javaapi/Predictor.scala |  2 +-
 7 files changed, 160 insertions(+), 46 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index b342a96..85f45bc 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -106,9 +106,9 @@ class Executor private[mxnet](private[mxnet] val handle: 
ExecutorHandle,
                         "is more efficient than the reverse." +
                         "If you really want to up size, set allowUpSizing = 
true " +
                         "to enable allocation of new arrays.")
-          newArgDict = newArgDict + (name -> NDArray.empty(newShape, 
arr.context))
+          newArgDict = newArgDict + (name -> NDArray.empty(newShape, 
arr.context, arr.dtype))
           if (dArr != null) {
-            newGradDict = newGradDict + (name -> NDArray.empty(newShape, 
dArr.context))
+            newGradDict = newGradDict + (name -> NDArray.empty(newShape, 
dArr.context, dArr.dtype))
           }
         } else {
           newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/predictor/PredictorExample.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/predictor/PredictorExample.scala
new file mode 100644
index 0000000..be90936
--- /dev/null
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/predictor/PredictorExample.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnetexamples.infer.predictor
+
+import java.io.File
+
+import scala.io
+import org.apache.mxnet._
+import org.apache.mxnet.infer.Predictor
+import org.apache.mxnetexamples.benchmark.CLIParserBase
+import org.kohsuke.args4j.{CmdLineParser, Option}
+
+import scala.collection.JavaConverters._
+
+/**
+ * Minimal end-to-end image classification example built on
+ * [[org.apache.mxnet.infer.Predictor]]: load a checkpoint, preprocess one
+ * image, run a forward pass and report the top-1 synset label.
+ */
+object PredictorExample {
+
+  /**
+   * Loads a model checkpoint into a [[Predictor]].
+   *
+   * @param modelPathPrefix checkpoint path prefix (e.g. `dir/resnet-18`)
+   * @param inputDesc       descriptors of the model inputs
+   * @param context         device to run inference on
+   * @param epoch           checkpoint epoch to load
+   * @return a Predictor configured with `inputDesc`
+   */
+  def loadModel(modelPathPrefix : String, inputDesc : IndexedSeq[DataDesc],
+                context : Context, epoch : Int): Predictor = {
+    new Predictor(modelPathPrefix, inputDesc, context, Some(epoch))
+  }
+
+  /**
+   * Runs inference on a single input array.
+   *
+   * @param predictor model returned by [[loadModel]]
+   * @param imageND   input array, typically produced by [[preProcess]]
+   * @return the outputs as returned by `Predictor.predictWithNDArray`
+   */
+  def doInference(predictor : Predictor, imageND : NDArray): IndexedSeq[NDArray] = {
+    predictor.predictWithNDArray(IndexedSeq(imageND))
+  }
+
+  /**
+   * Reads an image, resizes it and converts it to a batched float32 tensor.
+   *
+   * @param imagePath path of the image file
+   * @param h         target height in pixels
+   * @param w         target width in pixels
+   * @return a float32 NDArray of shape (1, C, h, w)
+   */
+  def preProcess(imagePath: String, h: Int, w: Int) : NDArray = {
+    val decoded = Image.imRead(imagePath)
+    // Image.imResize takes the width before the height; name the arguments
+    // so they cannot be swapped for non-square targets.
+    val resized = Image.imResize(decoded, w = w, h = h)
+    // HWC -> CHW, then prepend a batch axis (CHW -> NCHW)
+    val chw = NDArray.api.transpose(resized, Some(Shape(2, 0, 1)))
+    val nchw = NDArray.api.expand_dims(chw, 0)
+    nchw.asType(DType.Float32)
+  }
+
+  /**
+   * Maps the highest score in `result` to its label from `synset.txt`, which
+   * is looked up in the directory part of `modelPathPrefix`. The label and its
+   * score are printed to stdout.
+   *
+   * @param modelPathPrefix checkpoint path prefix; its directory must hold synset.txt
+   * @param result          per-class scores
+   * @return the label of the top-1 class
+   * @throws IllegalArgumentException if `result` is empty or the directory is missing
+   */
+  def postProcess(modelPathPrefix : String, result : Array[Float]) : String = {
+    require(result.nonEmpty, "result must contain at least one score")
+    val dirPath = modelPathPrefix.substring(0,
+      1 + modelPathPrefix.lastIndexOf(File.separator))
+    val d = new File(dirPath)
+    require(d.exists && d.isDirectory, s"directory: $dirPath not found")
+    // Fully qualified so it cannot clash with the `io` package that the
+    // org.apache.mxnet wildcard import may bring into scope.
+    val source = scala.io.Source.fromFile(dirPath + "synset.txt")
+    // getLines is lazy: materialize every line before closing the file.
+    val labels = try source.getLines().toIndexedSeq finally source.close()
+    val maxIdx = result.zipWithIndex.maxBy(_._1)._2
+    // println, not printf: a '%' inside a label must not be read as a format spec.
+    println(s"Predict Result ${labels(maxIdx)} with prob ${result(maxIdx)}")
+    labels(maxIdx)
+  }
+
+  /** Command-line entry point; see [[CLIParser]] for the accepted options. */
+  def main(args : Array[String]): Unit = {
+    val inst = new CLIParser
+    new CmdLineParser(inst).parseArgument(args.toList.asJava)
+
+    // SCALA_TEST_ON_GPU=1 selects the GPU; an unset or non-numeric value means CPU.
+    val useGpu = sys.env.get("SCALA_TEST_ON_GPU")
+      .exists(v => scala.util.Try(v.trim.toInt).getOrElse(0) == 1)
+    val context = if (useGpu) Context.gpu() else Context.cpu()
+
+    val imgWidth = 224
+    val imgHeight = 224
+    val inputDesc = IndexedSeq(new DataDesc("data", Shape(1, 3, imgHeight, imgWidth),
+      DType.Float32, Layout.NCHW))
+
+    val predictor = loadModel(inst.modelPathPrefix, inputDesc, context, 0)
+    val img = preProcess(inst.inputImagePath, imgHeight, imgWidth)
+    val result = doInference(predictor, img)(0).toArray
+    postProcess(inst.modelPathPrefix, result)
+  }
+
+}
+
+/**
+ * Command-line options for [[PredictorExample]]; the fields are filled in by
+ * args4j's `CmdLineParser` when `PredictorExample.main` parses its arguments.
+ */
+class CLIParser extends CLIParserBase {
+  @Option(usage = "the input model directory", name = "--model-path-prefix")
+  val modelPathPrefix: String = "/resnet-152/resnet-152"
+
+  @Option(usage = "the input image", name = "--input-image")
+  val inputImagePath: String = "/images/kitten.jpg"
+}
diff --git 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
index d8631df..27d9bb4 100644
--- 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
+++ 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
@@ -20,10 +20,7 @@ package org.apache.mxnetexamples.infer.imageclassifier
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.slf4j.LoggerFactory
 import java.io.File
-import java.net.URL
-
-import org.apache.commons.io.FileUtils
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet.Context
 import org.apache.mxnetexamples.Util
 
 import sys.process.Process
diff --git 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
index addc837..bd960bd 100644
--- 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
+++ 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/objectdetector/ObjectDetectorExampleSuite.scala
@@ -18,10 +18,7 @@
 package org.apache.mxnetexamples.infer.objectdetector
 
 import java.io.File
-import java.net.URL
-
-import org.apache.commons.io.FileUtils
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet.Context
 import org.apache.mxnetexamples.Util
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
 import org.slf4j.LoggerFactory
diff --git 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/predictor/PredictorExampleSuite.scala
similarity index 51%
copy from 
scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
copy to 
scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/predictor/PredictorExampleSuite.scala
index d8631df..97ca33e 100644
--- 
a/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExampleSuite.scala
+++ 
b/scala-package/examples/src/test/scala/org/apache/mxnetexamples/infer/predictor/PredictorExampleSuite.scala
@@ -15,27 +15,22 @@
  * limitations under the License.
  */
 
-package org.apache.mxnetexamples.infer.imageclassifier
+package org.apache.mxnetexamples.infer.predictor
 
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import org.slf4j.LoggerFactory
 import java.io.File
-import java.net.URL
 
-import org.apache.commons.io.FileUtils
-import org.apache.mxnet.{Context, NDArrayCollector}
+import org.apache.mxnet._
 import org.apache.mxnetexamples.Util
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.slf4j.LoggerFactory
 
-import sys.process.Process
-
-/**
-  * Integration test for imageClassifier example.
-  * This will run as a part of "make scalatest"
-  */
-class ImageClassifierExampleSuite extends FunSuite with BeforeAndAfterAll {
-  private val logger = 
LoggerFactory.getLogger(classOf[ImageClassifierExampleSuite])
+class PredictorExampleSuite extends FunSuite with BeforeAndAfterAll {
+  private val logger = LoggerFactory.getLogger(classOf[PredictorExampleSuite])
+  private var modelDirPrefix = ""
+  private var inputImagePath = ""
+  private var context = Context.cpu()
 
-  test("testImageClassifierExample") {
+  override def beforeAll(): Unit = {
     logger.info("Downloading resnet-18 model")
 
     val tempDirPath = System.getProperty("java.io.tmpdir")
@@ -52,27 +47,41 @@ class ImageClassifierExampleSuite extends FunSuite with 
BeforeAndAfterAll {
     
Util.downloadUrl("https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg";,
       tempDirPath + "/inputImages/resnet18/Pug-Cookie.jpg")
 
-    val modelDirPath = tempDirPath + File.separator + "resnet18/"
-    val inputImagePath = tempDirPath + File.separator +
+    modelDirPrefix = tempDirPath + File.separator + "resnet18/resnet-18"
+    inputImagePath = tempDirPath + File.separator +
       "inputImages/resnet18/Pug-Cookie.jpg"
-    val inputImageDir = tempDirPath + File.separator + "inputImages/resnet18/"
 
-    var context = Context.cpu()
     if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
       System.getenv("SCALA_TEST_ON_GPU").toInt == 1) {
       context = Context.gpu()
     }
+    val props = System.getProperties
+    props.setProperty("mxnet.disableShapeCheck", "true")
+  }
 
-    val output = ImageClassifierExample.runInferenceOnSingleImage(modelDirPath 
+ "resnet-18",
-     inputImagePath, context)
-
-    val outputList = 
ImageClassifierExample.runInferenceOnBatchOfImage(modelDirPath + "resnet-18",
-        inputImageDir, context)
-
-    Process("rm -rf " + modelDirPath + " " + inputImageDir) !
-
-    assert(output(0).toList.head._1 === "n02110958 pug, pug-dog")
-    assert(outputList(0).toList.head._1 === "n02110958 pug, pug-dog")
+  override def afterAll(): Unit = {
+    val props = System.getProperties
+    props.setProperty("mxnet.disableShapeCheck", "false")
+  }
 
+  test("test Predictor With Fixed Shape and random shape") {
+    val inputDesc = IndexedSeq(new DataDesc("data", Shape(1, 3, 224, 224),
+      DType.Float32, Layout.NCHW))
+    val predictor = PredictorExample.loadModel(modelDirPrefix, inputDesc, 
context, 0)
+    // fix size
+    var img = PredictorExample.preProcess(inputImagePath, 224, 224)
+    var result = PredictorExample.doInference(predictor, img)(0)
+    var top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
+    assert(top1 === "n02110958 pug, pug-dog")
+    // random size 512
+    img = PredictorExample.preProcess(inputImagePath, 512, 512)
+    result = PredictorExample.doInference(predictor, img)(0)
+    top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
+    assert(top1 === "n02110958 pug, pug-dog")
+    // original size
+    img = PredictorExample.preProcess(inputImagePath, 1024, 576)
+    result = PredictorExample.doInference(predictor, img)(0)
+    top1 = PredictorExample.postProcess(modelDirPrefix, result.toArray)
+    assert(top1 === "n02110958 pug, pug-dog")
   }
 }
diff --git 
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala 
b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
index e2a0e7c..d4bce9f 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
@@ -22,8 +22,10 @@ import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
 import org.apache.mxnet.module.Module
 
 import scala.collection.mutable.ListBuffer
+import scala.util.Try
 import org.slf4j.LoggerFactory
 
+
 /**
  * Base Trait for MXNet Predictor classes.
  */
@@ -76,6 +78,21 @@ class Predictor(modelPathPrefix: String,
 
   private val logger = LoggerFactory.getLogger(classOf[Predictor])
 
+  /*
+    Setting -Dmxnet.disableShapeCheck=true disables the predictor's input
+    shape check. Some models (e.g. Seq2Seq) accept inputs of varying length,
+    but inputs outside the range a model supports may cause crashes.
+    The property is read once per Predictor instance, on first use.
+   */
+  private val traceProperty = "mxnet.disableShapeCheck"
+  private lazy val shapeCheckDisabled = {
+    val value = 
Try(System.getProperty(traceProperty).toBoolean).getOrElse(false)
+    if (value) {
+      logger.warn("Shape check is disabled (property {} is set)", 
traceProperty)
+    }
+    value
+  }
+
   require(inputDescriptors.head.layout.size != 0, "layout size should not be 
zero")
 
   protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N')
@@ -172,9 +189,11 @@ class Predictor(modelPathPrefix: String,
     for((i, d) <- inputBatch.zip(iDescriptors)) {
        require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
          "All inputs should be of same batch size")
-      require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
-        s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
-          s"shape: ${d.shape} except batchSize")
+      if (!shapeCheckDisabled) {
+        require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
+          s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
+            s"shape: ${d.shape} except batchSize")
+      }
     }
 
     val inputBatchSize = inputBatch(0).shape(batchIndex)
@@ -182,8 +201,8 @@ class Predictor(modelPathPrefix: String,
     // rebind with the new batchSize
     if (batchSize != inputBatchSize) {
       logger.info(s"Latency increased due to batchSize mismatch $batchSize vs 
$inputBatchSize")
-      val desc = iDescriptors.map((f : DataDesc) => new DataDesc(f.name,
-        Shape(f.shape.toVector.patch(batchIndex, Vector(inputBatchSize), 1)), 
f.dtype, f.layout) )
+      val desc = inputBatch.zip(iDescriptors).map(f => new DataDesc(f._2.name,
+        f._1.shape, f._2.dtype, f._2.layout))
       mxNetHandler.execute(mod.bind(desc, forceRebind = true,
         forTraining = false))
     }
@@ -200,7 +219,7 @@ class Predictor(modelPathPrefix: String,
 
   private[infer] def loadModule(): Module = {
     val mod = mxNetHandler.execute(Module.loadCheckpoint(modelPathPrefix, 
epoch.get,
-      contexts = contexts))
+      contexts = contexts, dataNames = inputDescriptors.map(desc => 
desc.name)))
     mxNetHandler.execute(mod.bind(inputDescriptors, forTraining = false))
     mod
   }
diff --git 
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
 
b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
index c867168..8c48742 100644
--- 
a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
+++ 
b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/javaapi/Predictor.scala
@@ -93,7 +93,7 @@ class Predictor private[mxnet] (val predictor: 
org.apache.mxnet.infer.Predictor)
     * This method is useful when the input is a batch of data
     * Note: User is responsible for managing allocation/deallocation of 
input/output NDArrays.
     *
-    * @param input       List of NDArrays
+    * @param input             List of NDArrays
     * @return                  Output of predictions as NDArrays
     */
   def predictWithNDArray(input: java.util.List[NDArray]):

Reply via email to