reminisce commented on a change in pull request #16100: Infra for tvm op runtime dispatch
URL: https://github.com/apache/incubator-mxnet/pull/16100#discussion_r325491433
 
 

 ##########
 File path: contrib/tvmop/compile.py
 ##########
 @@ -37,23 +40,57 @@ def get_target(device):
     parser = argparse.ArgumentParser(description="Generate tvm operators")
     parser.add_argument("-o", action="store", required=True, 
dest="target_path",
                         help="Target path which stores compiled library")
+    parser.add_argument("--config", action="store", required=True, 
dest="config_path",
+                        help="Path which stores the config file")
     arguments = parser.parse_args()
 
     func_list_llvm = []
     func_list_cuda = []
+    config_spaces = ConfigSpaces()
 
     # TODO: attach instruction features to the library, e.g., avx-512, etc.
-    for operator_def in __OP_DEF__:
-        for sch, args, name in operator_def.invoke_all():
-            if tvm.module.enabled(get_target(operator_def.target)):
-                func_list = func_list_llvm if operator_def.target == "cpu" 
else func_list_cuda
-                func_lower = tvm.lower(sch, args,
-                                       name=name,
-                                       binds=operator_def.get_binds(args))
-                func_list.append(func_lower)
+    for op in __OP_DEF__:
+        if tvm.module.enabled(get_target(op.target)):
+            func_list = func_list_llvm if op.target == "cpu" else 
func_list_cuda
+            for each_kwargs in op.arg_combination:
+                if (op.attrs_valid(**each_kwargs)):
+                    name = op.name \
+                        + ''.join(["{}_{}".format(key, each_kwargs[key]) for 
key in op.attrs])
+                    if op.dispatch is True:
+                        config_space = autotvm.ConfigSpace()
+                        with autotvm.task.ApplyConfig(config_space):
+                            sch, args = op.func(fallback=False, **each_kwargs)
+                        # register dispatch schedules
+                        for i in range(len(config_space)):
+                            config_entity = config_space.get(i)
+                            with autotvm.task.ApplyConfig(config_entity):
+                                sch, args = op.func(fallback=False, 
**each_kwargs)
+                            subname = name + "index_" + str(i) + \
+                                ''.join(["%s_%d" % (arg.dtype, len(arg.shape)) 
for arg in args])
+                            func_lower = tvm.lower(sch, args,
+                                                name=subname,
 
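 Review comment:
   Fix indent: the continuation line `name=subname,` is not aligned with the first argument after the opening parenthesis of `tvm.lower(` on the previous line.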

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to