lanking520 commented on a change in pull request #10660: [MXNET-357] New Scala API Design (Symbol) URL: https://github.com/apache/incubator-mxnet/pull/10660#discussion_r185635887
########## File path: scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala ########## @@ -21,94 +21,101 @@ import scala.annotation.StaticAnnotation import scala.collection.mutable.ListBuffer import scala.language.experimental.macros import scala.reflect.macros.blackbox - import org.apache.mxnet.init.Base._ import org.apache.mxnet.utils.OperatorBuildUtils private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addDefs } +private[mxnet] class AddNewSymbolFunctions(isContrib: Boolean) extends StaticAnnotation { + private[mxnet] def macroTransform(annottees: Any*) = macro SymbolImplMacros.addNewDefs +} + private[mxnet] object SymbolImplMacros { - case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String) + case class SymbolArg(argName: String, argType: String, isOptional : Boolean) + case class SymbolFunction(name: String, listOfArgs: List[SymbolArg]) // scalastyle:off havetype def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { - impl(c)(false, annottees: _*) + impl(c)(false, false, annottees: _*) } - // scalastyle:off havetype + def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = { + impl(c)(false, true, annottees: _*) + } + // scalastyle:on havetype - private val symbolFunctions: Map[String, SymbolFunction] = initSymbolModule() + private val symbolFunctions: List[SymbolFunction] = initSymbolModule() - private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { + private def impl(c: blackbox.Context)(addSuper: Boolean, + newAPI: Boolean, annottees: c.Expr[Any]*): c.Expr[Any] = { import c.universe._ val isContrib: Boolean = c.prefix.tree match { case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) + case q"new AddNewSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b)) } val newSymbolFunctions = { - if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_")) 
- else symbolFunctions.filter(!_._1.startsWith("_contrib_")) + if (isContrib) symbolFunctions.filter(_.name.startsWith("_contrib_")) + else symbolFunctions.filter(!_.name.startsWith("_contrib_")) } - val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")), - List(Ident(TypeName("String")), Ident(TypeName("Any")))) - val AST_TYPE_MAP_STRING_STRING = AppliedTypeTree(Ident(TypeName("Map")), - List(Ident(TypeName("String")), Ident(TypeName("String")))) - val AST_TYPE_SYMBOL_VARARG = AppliedTypeTree( - Select( - Select(Ident(termNames.ROOTPKG), TermName("scala")), - TypeName("<repeated>") - ), - List(Select(Select(Select( - Ident(TermName("org")), TermName("apache")), TermName("mxnet")), TypeName("Symbol"))) - ) - - val functionDefs = newSymbolFunctions map { case (funcName, funcProp) => - val functionScope = { - if (isContrib) Modifiers() - else { - if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else Modifiers() - } + var functionDefs = List[DefDef]() + + if (!newAPI) { + functionDefs = newSymbolFunctions map { symbolfunction => + val funcName = symbolfunction.name + val tName = TermName(funcName) + q""" + def $tName(name : String = null, attr : Map[String, String] = null) + (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null) + : org.apache.mxnet.Symbol = { + createSymbolGeneral($funcName,name,attr,args,kwargs) + } + """.asInstanceOf[DefDef] } - val newName = { - if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + "_contrib_".length()) - else funcName + } else { + functionDefs = newSymbolFunctions map { symbolfunction => + + // Construct argument field + var argDef = ListBuffer[String]() + symbolfunction.listOfArgs.foreach(symbolarg => { + val currArgName = if (symbolarg.argName.equals("var")) "vari" else symbolarg.argName + if (symbolarg.isOptional) { + argDef += s"${currArgName} : Option[${symbolarg.argType}] = None" Review comment: If we pass them as None, the arg will go with their default value in C 
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services