This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e703694  Fixes #14181, validate model output shape for ObjectDetector. (#14215)
e703694 is described below

commit e70369437bea25b92bd5531b08fff92988b2ff02
Author: Frank Liu <frankfliu2...@gmail.com>
AuthorDate: Thu Mar 7 10:13:09 2019 -0800

    Fixes #14181, validate model output shape for ObjectDetector. (#14215)
---
 .../org/apache/mxnet/infer/ImageClassifier.scala   |  2 ++
 .../org/apache/mxnet/infer/ObjectDetector.scala    | 25 +++++++++++++++++++---
 .../scala/org/apache/mxnet/infer/Predictor.scala   |  7 ++++++
 3 files changed, 31 insertions(+), 3 deletions(-)

diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
index 3c80f92..99c0432 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
@@ -66,6 +66,8 @@ class ImageClassifier(modelPathPrefix: String,
   protected[infer] val height = inputShape(inputLayout.indexOf('H'))
   protected[infer] val width = inputShape(inputLayout.indexOf('W'))
 
+  def outputShapes: IndexedSeq[(String, Shape)] = predictor.outputShapes
+
   /**
     * To classify the image according to the provided model
     *
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
index 7146156..e29f068 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ObjectDetector.scala
@@ -20,12 +20,13 @@ package org.apache.mxnet.infer
 // scalastyle:off
 import java.awt.image.BufferedImage
 
+import org.apache.mxnet.Shape
+
 import scala.collection.parallel.mutable.ParArray
 // scalastyle:on
 import org.apache.mxnet.NDArray
 import org.apache.mxnet.DataDesc
 import org.apache.mxnet.Context
-import scala.collection.mutable.ListBuffer
 
 /**
   * The ObjectDetector class helps to run ObjectDetection tasks where the goal
@@ -174,7 +175,25 @@ class ObjectDetector(modelPathPrefix: String,
                          contexts: Array[Context] = Context.cpu(),
                          epoch: Option[Int] = Some(0)):
   ImageClassifier = {
-    new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
-  }
+    val imageClassifier: ImageClassifier =
+      new ImageClassifier(modelPathPrefix, inputDescriptors, contexts, epoch)
+
+    val shapes: IndexedSeq[(String, Shape)] = imageClassifier.outputShapes
+    if (shapes.length != inputDescriptors.length) {
+      throw new IllegalStateException(s"Invalid output shapes, expected:" +
+        s" $inputDescriptors.length, actual: $shapes.length.")
+    }
+    shapes.map(_._2).foreach(shape => {
+      if (shape.length < 3) {
+        throw new IllegalArgumentException("Invalid output shapes, the model 
doesn't"
+          + " support object detection.")
+      }
+      if (shape.get(2) < 6) {
+        throw new IllegalArgumentException("Invalid output shapes, the model 
doesn't"
+          + " support object detection with bounding box.")
+      }
+    })
 
+    imageClassifier
+  }
 }
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
index 67692a3..66284c8 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
@@ -56,6 +56,11 @@ private[infer] trait PredictBase {
    */
   def predictWithNDArray(input: IndexedSeq[NDArray]): IndexedSeq[NDArray]
 
+  /**
+    * Get model output shapes.
+    * @return   model output shapes.
+    */
+  def outputShapes: IndexedSeq[(String, Shape)]
 }
 
 /**
@@ -122,6 +127,8 @@ class Predictor(modelPathPrefix: String,
 
   protected[infer] val mod = loadModule()
 
+  override def outputShapes: IndexedSeq[(String, Shape)] = mod.outputShapes
+
   /**
    * Takes input as IndexedSeq one dimensional arrays and creates the NDArray 
needed for inference
    * The array will be reshaped based on the input descriptors.

Reply via email to