lanking520 commented on a change in pull request #10660: [MXNET-357] New Scala API Design (Symbol)
URL: https://github.com/apache/incubator-mxnet/pull/10660#discussion_r185635887
 
 

 ##########
 File path: 
scala-package/macros/src/main/scala/org/apache/mxnet/SymbolMacro.scala
 ##########
 @@ -21,94 +21,101 @@ import scala.annotation.StaticAnnotation
 import scala.collection.mutable.ListBuffer
 import scala.language.experimental.macros
 import scala.reflect.macros.blackbox
-
 import org.apache.mxnet.init.Base._
 import org.apache.mxnet.utils.OperatorBuildUtils
 
 private[mxnet] class AddSymbolFunctions(isContrib: Boolean) extends 
StaticAnnotation {
   private[mxnet] def macroTransform(annottees: Any*) = macro 
SymbolImplMacros.addDefs
 }
 
+private[mxnet] class AddNewSymbolFunctions(isContrib: Boolean) extends 
StaticAnnotation {
+  private[mxnet] def macroTransform(annottees: Any*) = macro 
SymbolImplMacros.addNewDefs
+}
+
 private[mxnet] object SymbolImplMacros {
-  case class SymbolFunction(handle: SymbolHandle, keyVarNumArgs: String)
+  case class SymbolArg(argName: String, argType: String, isOptional : Boolean)
+  case class SymbolFunction(name: String, listOfArgs: List[SymbolArg])
 
   // scalastyle:off havetype
   def addDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
-    impl(c)(false, annottees: _*)
+    impl(c)(false, false, annottees: _*)
   }
-  // scalastyle:off havetype
+  def addNewDefs(c: blackbox.Context)(annottees: c.Expr[Any]*) = {
+    impl(c)(false, true, annottees: _*)
+  }
+  // scalastyle:on havetype
 
-  private val symbolFunctions: Map[String, SymbolFunction] = initSymbolModule()
+  private val symbolFunctions: List[SymbolFunction] = initSymbolModule()
 
-  private def impl(c: blackbox.Context)(addSuper: Boolean, annottees: 
c.Expr[Any]*): c.Expr[Any] = {
+  private def impl(c: blackbox.Context)(addSuper: Boolean,
+                                        newAPI: Boolean, annottees: 
c.Expr[Any]*): c.Expr[Any] = {
     import c.universe._
 
     val isContrib: Boolean = c.prefix.tree match {
       case q"new AddSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b))
+      case q"new AddNewSymbolFunctions($b)" => c.eval[Boolean](c.Expr(b))
     }
 
     val newSymbolFunctions = {
-      if (isContrib) symbolFunctions.filter(_._1.startsWith("_contrib_"))
-      else symbolFunctions.filter(!_._1.startsWith("_contrib_"))
+      if (isContrib) symbolFunctions.filter(_.name.startsWith("_contrib_"))
+      else symbolFunctions.filter(!_.name.startsWith("_contrib_"))
     }
 
-    val AST_TYPE_MAP_STRING_ANY = AppliedTypeTree(Ident(TypeName("Map")),
-      List(Ident(TypeName("String")), Ident(TypeName("Any"))))
-    val AST_TYPE_MAP_STRING_STRING = AppliedTypeTree(Ident(TypeName("Map")),
-      List(Ident(TypeName("String")), Ident(TypeName("String"))))
-    val AST_TYPE_SYMBOL_VARARG = AppliedTypeTree(
-      Select(
-        Select(Ident(termNames.ROOTPKG), TermName("scala")),
-        TypeName("<repeated>")
-      ),
-      List(Select(Select(Select(
-        Ident(TermName("org")), TermName("apache")), TermName("mxnet")), 
TypeName("Symbol")))
-    )
-
-    val functionDefs = newSymbolFunctions map { case (funcName, funcProp) =>
-      val functionScope = {
-        if (isContrib) Modifiers()
-        else {
-          if (funcName.startsWith("_")) Modifiers(Flag.PRIVATE) else 
Modifiers()
-        }
+    var functionDefs = List[DefDef]()
+
+    if (!newAPI) {
+      functionDefs = newSymbolFunctions map { symbolfunction =>
+        val funcName = symbolfunction.name
+        val tName = TermName(funcName)
+        q"""
+            def $tName(name : String = null, attr : Map[String, String] = null)
+            (args : org.apache.mxnet.Symbol*)(kwargs : Map[String, Any] = null)
+             : org.apache.mxnet.Symbol = {
+              createSymbolGeneral($funcName,name,attr,args,kwargs)
+              }
+         """.asInstanceOf[DefDef]
       }
-      val newName = {
-        if (isContrib) funcName.substring(funcName.indexOf("_contrib_") + 
"_contrib_".length())
-        else funcName
+    } else {
+      functionDefs = newSymbolFunctions map { symbolfunction =>
+
+        // Construct argument field
+        var argDef = ListBuffer[String]()
+        symbolfunction.listOfArgs.foreach(symbolarg => {
+          val currArgName = if (symbolarg.argName.equals("var")) "vari" else 
symbolarg.argName
+          if (symbolarg.isOptional) {
+            argDef += s"${currArgName} : Option[${symbolarg.argType}] = None"
 
 Review comment:
  If we pass them as None, these arguments will fall back to their default values defined in C.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to