mariussoutier opened a new issue #12409: Scala: DataDesc IllegalArgumentException with simple example
URL: https://github.com/apache/incubator-mxnet/issues/12409
 
 
   ## Description
   
   Fitting a model with the Scala Module API throws an `IllegalArgumentException`.
   Apparently the label shape `(50)` doesn't match the expected NCHW layout.
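
   The failing requirement can be modeled in a few lines of plain Scala. This is a hypothetical sketch, not MXNet's `DataDesc`: `LayoutCheck` and `validate` are illustrative names, and it assumes, from the error message, that the rule is simply "the number of dimensions in the shape must equal the number of characters in the layout string". Under that assumption the data shape `(50,3,128,128)` passes against `NCHW`, while the label shape `(50)` fails with the same wording MXNet reported.

   ```scala
   // Hypothetical model of the rank check from the error message; not MXNet code.
   object LayoutCheck {
     // Right(()) when the shape's rank equals the layout length, otherwise
     // Left with a message worded like the one in the reported exception.
     def validate(shape: Seq[Int], layout: String): Either[String, Unit] =
       if (shape.length == layout.length) Right(())
       else Left(
         s"number of dimensions in shape :${shape.length} with shape: " +
           s"${shape.mkString("(", ",", ")")} should match the length of the layout: " +
           s"${layout.length} with layout: $layout")

     def main(args: Array[String]): Unit = {
       println(validate(Seq(50, 3, 128, 128), "NCHW")) // data: rank 4 vs "NCHW" (4)
       println(validate(Seq(50), "NCHW"))              // label: rank 1 vs "NCHW" (4)
     }
   }
   ```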
   
   ## Environment info (Required)
   
   macOS 10.13.6
   IntelliJ 2018.2.2
   Scala 2.11.12
   Java 1.8.0_121
   MXNet 1.2.1
   
   ## Error Message:
   
   ```
   Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: number of dimensions in shape :1 with shape: (50) should match the length of the layout: 4 with layout: NCHW
        at scala.Predef$.require(Predef.scala:224)
        at org.apache.mxnet.DataDesc.<init>(IO.scala:233)
        at org.apache.mxnet.DataDesc$$anonfun$ListMap2Descs$1.apply(IO.scala:256)
        at org.apache.mxnet.DataDesc$$anonfun$ListMap2Descs$1.apply(IO.scala:256)
        at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
        at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
        at scala.collection.Iterator$class.foreach(Iterator.scala:891)
        at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
        at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
        at scala.collection.AbstractIterable.foreach(Iterable.scala:54)
        at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
        at scala.collection.AbstractTraversable.map(Traversable.scala:104)
        at org.apache.mxnet.DataDesc$.ListMap2Descs(IO.scala:256)
        at org.apache.mxnet.module.BaseModule.fit(BaseModule.scala:399)
   ```
   
   
   ## Minimum reproducible example
   
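   ```scala
   import org.apache.mxnet.IO
   import org.apache.mxnet.module.{FitParams, Module}
   import org.apache.mxnet.optimizer.SGD

   // dataName, mlp (the network Symbol) and testDataIter are defined elsewhere.
   val trainDataIter = IO.ImageRecordIter(Map(
     "data_name" -> dataName,
     "path_imgrec" -> this.getClass.getResource("/data/mydata.rec").getFile,
     "data_shape" -> "(3,128,128)", // per-image shape; the batch size is prepended
     "batch_size" -> "50"
   ))

   val mod = new Module(mlp)
   mod.fit(
     trainDataIter,
     Some(testDataIter),
     numEpoch = 10,
     fitParams = new FitParams()
       .setOptimizer(new SGD(0.1f, 0.9f, 0.0001f)) // learning rate, momentum, weight decay
   )
   ```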
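
   Because the iterator is configured only through strings, it can help to see how those strings turn into the shapes printed below. This is a hypothetical illustration: `RecordIterParams` and its methods are not part of MXNet. It prepends `batch_size` to the parsed `data_shape`, and it assumes one scalar label per image, which is consistent with the printed label shape `(50)`.

   ```scala
   // Hypothetical typed view of the string parameters; not an MXNet API.
   object RecordIterParams {
     // Parses a tuple-style string such as "(3,128,128)" into Seq(3, 128, 128).
     def parseShape(s: String): Seq[Int] =
       s.trim.stripPrefix("(").stripSuffix(")")
         .split(",").map(_.trim).filter(_.nonEmpty).map(_.toInt).toSeq

     // The batch size becomes the leading (N) axis of the data shape.
     def dataShape(params: Map[String, String]): Seq[Int] =
       params("batch_size").trim.toInt +: parseShape(params("data_shape"))

     // Assumes one scalar label per image, so the label shape is just (N).
     def labelShape(params: Map[String, String]): Seq[Int] =
       Seq(params("batch_size").trim.toInt)

     def main(args: Array[String]): Unit = {
       val params = Map("data_shape" -> "(3,128,128)", "batch_size" -> "50")
       println(dataShape(params).mkString("(", ",", ")"))  // prints (50,3,128,128)
       println(labelShape(params).mkString("(", ",", ")")) // prints (50)
     }
   }
   ```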
   
   I tried debugging this, but it's pretty difficult to find out what's going on with a stringly typed API.
   
   `println(trainDataIter.provideData)` -> Map(data -> (50,3,128,128))
   `println(trainDataIter.provideLabel)` -> Map(label -> (50))
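
   These two maps show where the mismatch comes from: the data entry has rank 4 and fits `NCHW`, but the label entry has rank 1. Assuming, as the stack trace suggests, that `DataDesc.ListMap2Descs` gives every entry the same default `NCHW` layout, only the label is rejected. The hypothetical sketch below (`PerEntryLayout`, `layoutFor` and its rank-to-layout table are illustrative, not MXNet's API or its fix) contrasts that with choosing a layout per entry from its rank.

   ```scala
   // Hypothetical sketch, not MXNet code: one shared default layout versus a
   // layout chosen from each shape's rank.
   import scala.collection.immutable.ListMap

   object PerEntryLayout {
     // Assumed rank-to-layout table; "N" is the batch axis alone.
     def layoutFor(rank: Int): Option[String] = rank match {
       case 1 => Some("N")
       case 2 => Some("NC")
       case 4 => Some("NCHW")
       case _ => None
     }

     // Names of entries whose rank differs from the length of one shared layout.
     def mismatchedUnderDefault(shapes: ListMap[String, Seq[Int]],
                                default: String = "NCHW"): List[String] =
       shapes.collect { case (name, shape) if shape.length != default.length => name }.toList

     // Layout picked per entry; None means no layout is known for that rank.
     def perEntry(shapes: ListMap[String, Seq[Int]]): ListMap[String, Option[String]] =
       shapes.map { case (name, shape) => name -> layoutFor(shape.length) }

     def main(args: Array[String]): Unit = {
       // The shapes printed by provideData and provideLabel in this report.
       val shapes = ListMap("data" -> Seq(50, 3, 128, 128), "label" -> Seq(50))
       println(mismatchedUnderDefault(shapes)) // entries rejected under NCHW
       println(perEntry(shapes))               // layouts chosen by rank
     }
   }
   ```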
   
