This is an automated email from the ASF dual-hosted git repository. liuyizhi pushed a commit to branch v1.2.0-java in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit c887376f24df1e3ce941f600c69b38026e71771d Author: Yizhi Liu <yizhi...@amazon.com> AuthorDate: Mon May 21 13:48:31 2018 -0700 add Builder and varargs which are java-friendly --- .../core/src/main/scala/org/apache/mxnet/IO.scala | 56 +++++++++++++++++++++- .../src/main/scala/org/apache/mxnet/NDArray.scala | 3 +- .../src/main/scala/org/apache/mxnet/Shape.scala | 4 ++ .../src/main/scala/org/apache/mxnet/Symbol.scala | 1 - .../scala/org/apache/mxnet/module/BaseModule.scala | 30 ++++++++++++ .../scala/org/apache/mxnet/module/Module.scala | 43 ++++++++++++++++- .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 52 ++------------------ 7 files changed, 138 insertions(+), 51 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 7a9c1a7..123e2f8 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -19,9 +19,10 @@ package org.apache.mxnet import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType -import org.apache.mxnet.io.{MXDataPack, MXDataIter} +import org.apache.mxnet.io.{MXDataIter, MXDataPack} import org.slf4j.LoggerFactory +import scala.annotation.varargs import scala.collection.immutable.ListMap import scala.collection.mutable.ListBuffer @@ -140,6 +141,7 @@ class DataBatch(val data: IndexedSeq[NDArray], // (must match the order of input data/label) private val providedData: ListMap[String, Shape] = null, private val providedLabel: ListMap[String, Shape] = null) { + /** * Dispose its data and labels * The object shall never be used after it is disposed. @@ -160,6 +162,58 @@ class DataBatch(val data: IndexedSeq[NDArray], def provideLabel: ListMap[String, Shape] = providedLabel } +object DataBatch { + class Builder() { + private var data: IndexedSeq[NDArray] = null + private var label: IndexedSeq[NDArray] = null + private var index: IndexedSeq[Long] = null + private var pad: Int = 0 + private var bucketKey: AnyRef = null + private var providedData: ListMap[String, Shape] = ListMap.empty + private var providedLabel: ListMap[String, Shape] = ListMap.empty + + @varargs def setData(data: NDArray*): Builder = { + this.data = data.toIndexedSeq + this + } + + @varargs def setLabel(label: NDArray*): Builder = { + this.label = label.toIndexedSeq + this + } + + @varargs def setIndex(index: Long*): Builder = { + this.index = index.toIndexedSeq + this + } + + def setPad(pad: Int): Builder = { + this.pad = pad + this + } + + def setBucketKey(bucketKey: AnyRef): Builder = { + this.bucketKey = bucketKey + this + } + + def provideData(name: String, shape: Shape): Builder = { + providedData = providedData.updated(name, shape) + this + } + + def provideLabel(name: String, shape: Shape): Builder = { + providedLabel = providedLabel.updated(name, shape) + this + } + + def build(): DataBatch = { + new DataBatch(data, label, index, pad, + bucketKey, providedData, providedLabel) + } + } +} + /** * DataIter object in mxnet. */ diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala index 416f2d7..e8c687e 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala @@ -48,6 +48,7 @@ object NDArray { } } + // private[mxnet] def genericNDArrayFunctionInvoke( /** * Used by NDArrayMacro. * Invoke this function by passing in parameters. @@ -57,7 +58,7 @@ object NDArray { * @param kwargs Key-value arguments of input scalars * @return The result NDArrays of result of computation. */ - private[mxnet] def genericNDArrayFunctionInvoke( + def genericNDArrayFunctionInvoke( funcName: String, args: Seq[Any], kwargs: Map[String, Any] = null): NDArrayFuncReturn = { val function = functions(funcName) val ndArgs = ArrayBuffer.empty[NDArray] diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala index e632ade..6891762 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala @@ -17,6 +17,8 @@ package org.apache.mxnet +import scala.annotation.varargs + /** * Shape of [[NDArray]] or other data */ @@ -28,6 +30,7 @@ class Shape(dims: Traversable[Int]) extends Serializable { } def apply(dim: Int): Int = shape(dim) + def get(dim: Int): Int = apply(dim) def size: Int = shape.size def length: Int = shape.length def drop(dim: Int): Shape = new Shape(shape.drop(dim)) @@ -56,4 +59,5 @@ class Shape(dims: Traversable[Int]) extends Serializable { object Shape { def apply(dims: Int *): Shape = new Shape(dims: _*) def apply(dims: Traversable[Int]): Shape = new Shape(dims) + @varargs def create(dims: Int*): Shape = new Shape(dims) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 13f85a7..b6947b4 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -101,7 +101,6 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD var index: Int = -1 for ((output, i) <- listOutputs().view.zipWithIndex) { if (output == name) { - require(index == -1, s"There are multiple outputs with name $name") index = i } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala index 108cff4..f7ae883 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala @@ -23,6 +23,8 @@ import org.apache.mxnet.optimizer.SGD import org.apache.mxnet._ import org.slf4j.LoggerFactory import org.slf4j.Logger + +import scala.annotation.varargs import scala.collection.mutable.ArrayBuffer object BaseModule { @@ -468,6 +470,10 @@ abstract class BaseModule { */ def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit + def forward(dataBatch: DataBatch, isTrain: Boolean): Unit = { + forward(dataBatch, Option(isTrain)) + } + /** * Backward computation. * @param outGrads Gradient on the outputs to be propagated back. @@ -549,6 +555,30 @@ abstract class BaseModule { forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None, gradReq: String = "write"): Unit + + protected var labelShapesPartial: IndexedSeq[DataDesc] = _ + protected var sharedModulePartial: BaseModule = _ + protected var gradReqPartial: String = "write" + @varargs def bindPartial(labelShape: DataDesc*): BaseModule = { + labelShapesPartial = labelShape.toIndexedSeq + this + } + def bindPartial(sharedModule: BaseModule): BaseModule = { + sharedModulePartial = sharedModule + this + } + def bindPartial(gradReq: String): BaseModule = { + gradReqPartial = gradReq + this + } + + @varargs def bind(forTraining: Boolean, inputsNeedGrad: Boolean, + forceRebind: Boolean, dataShape: DataDesc*): Unit = { + bind(dataShape.toVector, Option(labelShapesPartial), + forTraining, inputsNeedGrad, forceRebind, + Option(sharedModulePartial), gradReqPartial) + } + // Install and initialize optimizers. def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(), resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala index ac3d645..a46b605 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala @@ -17,13 +17,16 @@ package org.apache.mxnet.module -import java.io.{FileInputStream, BufferedInputStream, BufferedOutputStream, FileOutputStream} +import java.io.{BufferedInputStream, BufferedOutputStream, FileInputStream, FileOutputStream} + import org.apache.mxnet.DType.DType import org.apache.mxnet._ import org.apache.mxnet.module.DataParallelExecutorGroup.Builder import org.apache.mxnet.optimizer.SGD import org.slf4j.LoggerFactory +import scala.annotation.varargs + /** * Module is a basic module that wrap a `Symbol`. It is functionally the same * as the `FeedForward` model, except under the module API. @@ -642,4 +645,42 @@ object Module { } mod } + + class Builder (private val modelDef: Symbol) { + private var dataNames: IndexedSeq[String] = IndexedSeq("data") + private var labelNames: IndexedSeq[String] = IndexedSeq("softmax_label") + private var contexts: Array[Context] = Array(Context.cpu()) + private var workLoadList: IndexedSeq[Float] = _ + private var fixedParamNames: Set[String] = _ + + @varargs def setContext(ctx: Context*): Builder = { + contexts = ctx.toArray + this + } + + @varargs def setDataNames(name: String*): Builder = { + dataNames = name.toVector + this + } + + @varargs def setLabelNames(name: String*): Builder = { + labelNames = name.toVector + this + } + + @varargs def setWorkLoadList(workload: Float*): Builder = { + workLoadList = workload.toVector + this + } + + @varargs def setFixedParamNames(name: String*): Builder = { + fixedParamNames = name.toSet + this + } + + def build(): Module = { + new Module(modelDef, dataNames, labelNames, contexts, + Option(workLoadList), Option(fixedParamNames)) + } + } } diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index c26d14c..c4d16bc 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -52,18 +52,6 @@ private[mxnet] object NDArrayMacro { else ndarrayFunctions.filter(!_._1.startsWith("_contrib_")) } - val AST_NDARRAY_TYPE = Select(Select(Select( - Ident(TermName("org")), TermName("apache")), TermName("mxnet")), TypeName("NDArray")) - val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")), - List(Ident(TypeName("String")), Ident(TypeName("Any")))) - val AST_TYPE_ANY_VARARG = AppliedTypeTree( - Select( - Select(Ident(termNames.ROOTPKG), TermName("scala")), - TypeName("<repeated>") - ), - List(Ident(TypeName("Any"))) - ) - val functionDefs = newNDArrayFunctions flatMap { case (funcName, funcProp) => val functionScope = { if (isContrib) Modifiers() @@ -75,45 +63,15 @@ private[mxnet] object NDArrayMacro { if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length()) else funcName } - + val termName = TermName(funcName) // It will generate definition something like, Seq( + // scalastyle:off // def transpose(kwargs: Map[String, Any] = null)(args: Any*) - DefDef(functionScope, TermName(newName), List(), - List( - List( - ValDef(Modifiers(Flag.PARAM | Flag.DEFAULTPARAM), TermName("kwargs"), - AST_TYPE_MAP_STRING_ANY, Literal(Constant(null))) - ), - List( - ValDef(Modifiers(), TermName("args"), AST_TYPE_ANY_VARARG, EmptyTree) - ) - ), TypeTree(), - Apply( - Ident(TermName("genericNDArrayFunctionInvoke")), - List( - Literal(Constant(funcName)), - Ident(TermName("args")), - Ident(TermName("kwargs")) - ) - ) - ), + q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}", // def transpose(args: Any*) - DefDef(functionScope, TermName(newName), List(), - List( - List( - ValDef(Modifiers(), TermName("args"), AST_TYPE_ANY_VARARG, EmptyTree) - ) - ), TypeTree(), - Apply( - Ident(TermName("genericNDArrayFunctionInvoke")), - List( - Literal(Constant(funcName)), - Ident(TermName("args")), - Literal(Constant(null)) - ) - ) - ) + q"@scala.annotation.varargs def $termName(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, null)}" + // scalastyle:on ) } -- To stop receiving notification emails like this one, please contact liuyi...@apache.org.