This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch v1.2.0-java
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit c887376f24df1e3ce941f600c69b38026e71771d
Author: Yizhi Liu <yizhi...@amazon.com>
AuthorDate: Mon May 21 13:48:31 2018 -0700

    add Java-friendly Builders and varargs overloads
---
 .../core/src/main/scala/org/apache/mxnet/IO.scala  | 56 +++++++++++++++++++++-
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |  3 +-
 .../src/main/scala/org/apache/mxnet/Shape.scala    |  4 ++
 .../src/main/scala/org/apache/mxnet/Symbol.scala   |  1 -
 .../scala/org/apache/mxnet/module/BaseModule.scala | 30 ++++++++++++
 .../scala/org/apache/mxnet/module/Module.scala     | 43 ++++++++++++++++-
 .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 52 ++------------------
 7 files changed, 138 insertions(+), 51 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index 7a9c1a7..123e2f8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -19,9 +19,10 @@ package org.apache.mxnet
 
 import org.apache.mxnet.Base._
 import org.apache.mxnet.DType.DType
-import org.apache.mxnet.io.{MXDataPack, MXDataIter}
+import org.apache.mxnet.io.{MXDataIter, MXDataPack}
 import org.slf4j.LoggerFactory
 
+import scala.annotation.varargs
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ListBuffer
 
@@ -140,6 +141,7 @@ class DataBatch(val data: IndexedSeq[NDArray],
                 // (must match the order of input data/label)
                 private val providedData: ListMap[String, Shape] = null,
                 private val providedLabel: ListMap[String, Shape] = null) {
+
   /**
   * Disposes its data and labels.
   * The object must not be used after it is disposed.
@@ -160,6 +162,58 @@ class DataBatch(val data: IndexedSeq[NDArray],
   def provideLabel: ListMap[String, Shape] = providedLabel
 }
 
+object DataBatch {
+  class Builder() {
+    private var data: IndexedSeq[NDArray] = null
+    private var label: IndexedSeq[NDArray] = null
+    private var index: IndexedSeq[Long] = null
+    private var pad: Int = 0
+    private var bucketKey: AnyRef = null
+    private var providedData: ListMap[String, Shape] = ListMap.empty
+    private var providedLabel: ListMap[String, Shape] = ListMap.empty
+
+    @varargs def setData(data: NDArray*): Builder = {
+      this.data = data.toIndexedSeq
+      this
+    }
+
+    @varargs def setLabel(label: NDArray*): Builder = {
+      this.label = label.toIndexedSeq
+      this
+    }
+
+    @varargs def setIndex(index: Long*): Builder = {
+      this.index = index.toIndexedSeq
+      this
+    }
+
+    def setPad(pad: Int): Builder = {
+      this.pad = pad
+      this
+    }
+
+    def setBucketKey(bucketKey: AnyRef): Builder = {
+      this.bucketKey = bucketKey
+      this
+    }
+
+    def provideData(name: String, shape: Shape): Builder = {
+      providedData = providedData.updated(name, shape)
+      this
+    }
+
+    def provideLabel(name: String, shape: Shape): Builder = {
+      providedLabel = providedLabel.updated(name, shape)
+      this
+    }
+
+    def build(): DataBatch = {
+      new DataBatch(data, label, index, pad,
+        bucketKey, providedData, providedLabel)
+    }
+  }
+}
+
 /**
  * DataIter object in mxnet.
  */
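
The Builder above targets Java callers, who would otherwise have to construct
Scala IndexedSeq and ListMap values by hand. A minimal usage sketch, assuming
NDArray values data and label already exist and that the nested class is
visible to Java as DataBatch.Builder:

    DataBatch batch = new DataBatch.Builder()
        .setData(data)            // @varargs: one or more NDArrays
        .setLabel(label)          // @varargs: one or more NDArrays
        .setIndex(0L, 1L, 2L)     // @varargs: per-example indices
        .setPad(0)
        .build();
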
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 416f2d7..e8c687e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -48,6 +48,7 @@ object NDArray {
     }
   }
 
+  //  private[mxnet] def genericNDArrayFunctionInvoke(
   /**
    * Used by NDArrayMacro.
    * Invoke this function by passing in parameters.
@@ -57,7 +58,7 @@ object NDArray {
    * @param kwargs Key-value arguments of input scalars
   * @return The resulting NDArrays of the computation.
    */
-  private[mxnet] def genericNDArrayFunctionInvoke(
+  def genericNDArrayFunctionInvoke(
    funcName: String, args: Seq[Any], kwargs: Map[String, Any] = null): NDArrayFuncReturn = {
     val function = functions(funcName)
     val ndArgs = ArrayBuffer.empty[NDArray]
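
Widening genericNDArrayFunctionInvoke from private[mxnet] to public lets
non-Scala code invoke operators by name, at the cost of building the Scala
collections explicitly. A rough Java sketch, assuming an existing NDArray x;
"transpose" is only an example operator name, and JavaConversions is the
Scala 2.11-era conversion API:

    scala.collection.Seq<Object> args =
        scala.collection.JavaConversions.asScalaBuffer(
            java.util.Arrays.<Object>asList(x));
    NDArrayFuncReturn ret =
        NDArray.genericNDArrayFunctionInvoke("transpose", args, null);  // null = no kwargs
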
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
index e632ade..6891762 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
@@ -17,6 +17,8 @@
 
 package org.apache.mxnet
 
+import scala.annotation.varargs
+
 /**
  * Shape of [[NDArray]] or other data
  */
@@ -28,6 +30,7 @@ class Shape(dims: Traversable[Int]) extends Serializable {
   }
 
   def apply(dim: Int): Int = shape(dim)
+  def get(dim: Int): Int = apply(dim)
   def size: Int = shape.size
   def length: Int = shape.length
   def drop(dim: Int): Shape = new Shape(shape.drop(dim))
@@ -56,4 +59,5 @@ class Shape(dims: Traversable[Int]) extends Serializable {
 object Shape {
   def apply(dims: Int *): Shape = new Shape(dims: _*)
   def apply(dims: Traversable[Int]): Shape = new Shape(dims)
+  @varargs def create(dims: Int*): Shape = new Shape(dims)
 }
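
Neither Shape.apply(dims: Int*) nor apply(dim) surfaces nicely in Java (the
apply sugar is lost and Int varargs arrive as a Scala Seq), which is what the
get alias and the @varargs create factory address. A sketch of the intended
Java usage:

    Shape shape = Shape.create(3, 224, 224);  // plain int varargs from Java
    int channels = shape.get(0);              // same as shape(0) in Scala
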
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 13f85a7..b6947b4 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -101,7 +101,6 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
     var index: Int = -1
     for ((output, i) <- listOutputs().view.zipWithIndex) {
       if (output == name) {
-        require(index == -1, s"There are multiple outputs with name $name")
         index = i
       }
     }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 108cff4..f7ae883 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -23,6 +23,8 @@ import org.apache.mxnet.optimizer.SGD
 import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 import org.slf4j.Logger
+
+import scala.annotation.varargs
 import scala.collection.mutable.ArrayBuffer
 
 object BaseModule {
@@ -468,6 +470,10 @@ abstract class BaseModule {
    */
   def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit
 
+  def forward(dataBatch: DataBatch, isTrain: Boolean): Unit = {
+    forward(dataBatch, Option(isTrain))
+  }
+
   /**
    * Backward computation.
    * @param outGrads Gradient on the outputs to be propagated back.
@@ -549,6 +555,30 @@ abstract class BaseModule {
           forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None,
            gradReq: String = "write"): Unit
 
+
+  protected var labelShapesPartial: IndexedSeq[DataDesc] = _
+  protected var sharedModulePartial: BaseModule = _
+  protected var gradReqPartial: String = "write"
+  @varargs def bindPartial(labelShape: DataDesc*): BaseModule = {
+    labelShapesPartial = labelShape.toIndexedSeq
+    this
+  }
+  def bindPartial(sharedModule: BaseModule): BaseModule = {
+    sharedModulePartial = sharedModule
+    this
+  }
+  def bindPartial(gradReq: String): BaseModule = {
+    gradReqPartial = gradReq
+    this
+  }
+
+  @varargs def bind(forTraining: Boolean, inputsNeedGrad: Boolean,
+                    forceRebind: Boolean, dataShape: DataDesc*): Unit = {
+    bind(dataShape.toVector, Option(labelShapesPartial),
+      forTraining, inputsNeedGrad, forceRebind,
+      Option(sharedModulePartial), gradReqPartial)
+  }
+
   // Install and initialize optimizers.
  def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
                    resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit
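
Together, the boolean forward overload, the bindPartial setters, and the
@varargs bind give Java callers a path around Scala's Option and default
arguments. A hedged sketch, assuming module, batch, dataDesc, and labelDesc
already exist:

    module.bindPartial(labelDesc)               // label shapes (varargs)
          .bindPartial("write");                // gradient request
    module.bind(true, false, false, dataDesc);  // forTraining, inputsNeedGrad,
                                                // forceRebind, data shapes (varargs)
    module.forward(batch, true);                // isTrain as a plain boolean
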
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index ac3d645..a46b605 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -17,13 +17,16 @@
 
 package org.apache.mxnet.module
 
-import java.io.{FileInputStream, BufferedInputStream, BufferedOutputStream, FileOutputStream}
+import java.io.{BufferedInputStream, BufferedOutputStream, FileInputStream, FileOutputStream}
+
 import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 import org.apache.mxnet.module.DataParallelExecutorGroup.Builder
 import org.apache.mxnet.optimizer.SGD
 import org.slf4j.LoggerFactory
 
+import scala.annotation.varargs
+
 /**
 * Module is a basic module that wraps a `Symbol`. It is functionally the same
  * as the `FeedForward` model, except under the module API.
@@ -642,4 +645,42 @@ object Module {
     }
     mod
   }
+
+  class Builder (private val modelDef: Symbol) {
+    private var dataNames: IndexedSeq[String] = IndexedSeq("data")
+    private var labelNames: IndexedSeq[String] = IndexedSeq("softmax_label")
+    private var contexts: Array[Context] = Array(Context.cpu())
+    private var workLoadList: IndexedSeq[Float] = _
+    private var fixedParamNames: Set[String] = _
+
+    @varargs def setContext(ctx: Context*): Builder = {
+      contexts = ctx.toArray
+      this
+    }
+
+    @varargs def setDataNames(name: String*): Builder = {
+      dataNames = name.toVector
+      this
+    }
+
+    @varargs def setLabelNames(name: String*): Builder = {
+      labelNames = name.toVector
+      this
+    }
+
+    @varargs def setWorkLoadList(workload: Float*): Builder = {
+      workLoadList = workload.toVector
+      this
+    }
+
+    @varargs def setFixedParamNames(name: String*): Builder = {
+      fixedParamNames = name.toSet
+      this
+    }
+
+    def build(): Module = {
+      new Module(modelDef, dataNames, labelNames, contexts,
+                 Option(workLoadList), Option(fixedParamNames))
+    }
+  }
 }
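
As with DataBatch, this Builder lets Java code fill in what are Option and
Seq parameters on Module's Scala constructor. A minimal sketch, assuming a
Symbol net describing the network and that the nested class is visible to
Java as Module.Builder:

    Module mod = new Module.Builder(net)
        .setContext(Context.cpu())       // @varargs: one or more devices
        .setDataNames("data")
        .setLabelNames("softmax_label")
        .build();
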
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index c26d14c..c4d16bc 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -52,18 +52,6 @@ private[mxnet] object NDArrayMacro {
       else ndarrayFunctions.filter(!_._1.startsWith("_contrib_"))
     }
 
-    val AST_NDARRAY_TYPE = Select(Select(Select(
-      Ident(TermName("org")), TermName("apache")), TermName("mxnet")), TypeName("NDArray"))
-    val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")),
-      List(Ident(TypeName("String")), Ident(TypeName("Any"))))
-    val AST_TYPE_ANY_VARARG = AppliedTypeTree(
-      Select(
-        Select(Ident(termNames.ROOTPKG), TermName("scala")),
-        TypeName("<repeated>")
-      ),
-      List(Ident(TypeName("Any")))
-    )
-
     val functionDefs = newNDArrayFunctions flatMap { case (funcName, funcProp) =>
       val functionScope = {
         if (isContrib) Modifiers()
@@ -75,45 +63,15 @@ private[mxnet] object NDArrayMacro {
         if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length())
         else funcName
       }
-
+      val termName = TermName(funcName)
       // It will generate definitions something like:
       Seq(
+        // scalastyle:off
         // def transpose(kwargs: Map[String, Any] = null)(args: Any*)
-        DefDef(functionScope, TermName(newName), List(),
-          List(
-            List(
-              ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("kwargs"),
-                AST_TYPE_MAP_STRING_ANY, Literal(Constant(null)))
-            ),
-            List(
-              ValDef(Modifiers(), TermName("args"), AST_TYPE_ANY_VARARG, EmptyTree)
-            )
-          ), TypeTree(),
-          Apply(
-            Ident(TermName("genericNDArrayFunctionInvoke")),
-            List(
-              Literal(Constant(funcName)),
-              Ident(TermName("args")),
-              Ident(TermName("kwargs"))
-            )
-          )
-        ),
+        q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, kwargs)}",
         // def transpose(args: Any*)
-        DefDef(functionScope, TermName(newName), List(),
-          List(
-            List(
-              ValDef(Modifiers(), TermName("args"), AST_TYPE_ANY_VARARG, EmptyTree)
-            )
-          ), TypeTree(),
-          Apply(
-            Ident(TermName("genericNDArrayFunctionInvoke")),
-            List(
-              Literal(Constant(funcName)),
-              Ident(TermName("args")),
-              Literal(Constant(null))
-            )
-          )
-        )
+        q"@scala.annotation.varargs def $termName(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, null)}"
+        // scalastyle:on
       )
     }
 

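Besides shrinking the macro to quasiquotes, the change annotates the
single-parameter-list form with @scala.annotation.varargs, so the compiler
also emits an Object... overload that Java can call directly. A hedged
sketch, where the operator name relu and the head accessor on
NDArrayFuncReturn are assumed for illustration:

    NDArrayFuncReturn out = NDArray.relu(input);  // generated @varargs method
    NDArray activated = out.head();               // first result of the call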