mbs-octoml commented on a change in pull request #8802:
URL: https://github.com/apache/tvm/pull/8802#discussion_r692530723



##########
File path: src/relay/backend/aot_executor_codegen.cc
##########
@@ -48,6 +48,7 @@ namespace backend {
 using IntegerArray = Array<Integer>;
 using StorageMap =
     std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, 
runtime::ObjectPtrEqual>;
+using namespace tec;

Review comment:
       (stylistic question) It looks like this is an established pattern in the 
code but the style guide discourages this:
   https://google.github.io/styleguide/cppguide.html#Namespaces

##########
File path: python/tvm/relay/backend/graph_executor_codegen.py
##########
@@ -53,7 +53,7 @@ def __init__(self, mod, target):
         self._get_irmodule = self._mod["get_irmodule"]
         self._setup(mod, target)
 
-    def _setup(self, mod, target):
+    def _setup(self, mod, target: Dict[int, Target]):

Review comment:
       This type is stronger than the dynamic check below; perhaps you put it 
in while you were grokking the code?

##########
File path: src/relay/backend/graph_executor_codegen.cc
##########
@@ -55,6 +55,7 @@ using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
 using GraphObjectPtr = std::shared_ptr<GraphNode>;
 using GraphInputObjectPtr = std::shared_ptr<GraphInputNode>;
 using GraphOpObjectPtr = std::shared_ptr<GraphOpNode>;
+using namespace tec;

Review comment:
       ditto

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap 
targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // copy main module funcs to unified funcs (what target do we need to 
annotate with here?)
+  for (const auto& kv : mod.main_module->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    ICHECK(!func->IsInstance<tir::PrimFuncNode>());
+    unified_funcs.Set(var, func);
+  }
+
+  // copy the type definitions for the main module
+  for (const auto& kv : mod.main_module->type_definitions) {
+    const GlobalTypeVar& ty_var = kv.first;
+    const TypeData& ty_data = kv.second;
+    unified_type_defs.Set(ty_var, ty_data);
+  }
+  // Move functions in per target IRModule into unified module
+  // Also move the type definitions
+  for (const auto& kv : mod.per_target_module) {
+    const String target = kv.first;
+    const IRModule target_module = kv.second;
+    // Move the per module functions, and annotate the funcs with their target
+    for (const auto& kv : target_module->functions) {
+      const GlobalVar& var = kv.first;
+      const BaseFunc& func = kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this 
point, but got "
+          << func->GetTypeKey();
+      tir::PrimFunc primFunc = 
WithAttr(Downcast<tir::PrimFunc>(std::move(func)), attr::kTarget,
+                                        runtime::String(target));
+      unified_funcs.Set(var, primFunc);
+    }
+
+    // Move the type definitions for the per target IRModule
+    for (const auto& kv : target_module->type_definitions) {
+      const GlobalTypeVar& ty_var = kv.first;
+      const TypeData& ty_data = kv.second;
+      unified_type_defs.Set(ty_var, ty_data);
+    }
+  }
+
+  IRModule ret_mod =
+      WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", 
mod.external_mods);
+  ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info);
+  return ret_mod;
+}
+
+LoweredModule IRModuleToLoweredModule(IRModule mod) {
+  Map<GlobalVar, BaseFunc> main_mod_funcs;
+  Map<String, Map<GlobalVar, BaseFunc>> target_funcs;
+  for (const auto& kv : mod->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    if (func->IsInstance<relay::FunctionNode>()) {
+      main_mod_funcs.Set(var, func);
+    } else if (func->IsInstance<tir::PrimFuncNode>()) {
+      // Extract target
+      auto target = func->GetAttr<String>(attr::kTarget);
+      ICHECK(!target) << "Target should be set at this point";
+
+      // Put the function in target_funcs
+      if (!target_funcs.count(target.value())) {
+        // Initialize the map and put it in target_funcs
+        Map<GlobalVar, BaseFunc> funcs;
+        funcs.Set(var, func);
+        target_funcs.Set(target.value(), funcs);
+
+      } else {
+        // The map is initialized, so just add the function.
+        Map<GlobalVar, BaseFunc> funcs = target_funcs.at(target.value());
+        funcs.Set(var, func);
+      }
+    } else {
+      LOG(FATAL)
+          << "The function types in the IRModule should be RelayFunction or 
PrimFunc, but got "
+          << func->GetTypeKey();
+    }
+  }
+  // Create the per_target_module map
+  Map<String, IRModule> per_target_modules;
+  for (const auto& kv : target_funcs) {
+    String target = kv.first;
+    Map<GlobalVar, BaseFunc> funcs = kv.second;
+    // Here, we just copy the type defs to every module. Since TIR doesn't use 
the type defs,
+    // this duplication should be OK.
+    per_target_modules.Set(target, IRModule(funcs, mod->type_definitions));
+  }
+  LoweredModule lowered_module;
+  lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions);
+  lowered_module.per_target_module = per_target_modules;
+
+  // Extract external modules and main func info, add to lowered module if 
they exist
+  auto external_mods = 
mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
+  if (external_mods) {
+    lowered_module.external_mods = external_mods.value();
+  }
+  auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info");
+  if (main_func_info) {
+    lowered_module.main_func_info = main_func_info.value();
+  }
+  return lowered_module;
+}
+
+Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
+                 backend::StaticMemoryPlan memory_plan, const String& 
module_name,
+                 std::function<void(Function)> process_fn) {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = 
[=](IRModule module,
+                                                                            
PassContext ctx) {
+    return LoweredModuleToIRModule(
+        LowerTE(module, targets, device_context_map, memory_plan, module_name, 
process_fn));
+  };
+  return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
+}

Review comment:
       [A question to the community and not specific to your PR Lily!]
   
   This is a good example of code which could be easily unit tested in C++ in 
the, er, 'conventional' sense. That is, as a reader I could expect to go to 
tests/cpp/relay/backend/te_compiler_test.cc and look for 
TEST(IRModuleToLoweredModule, ...). Currently this new code is tested 
indirectly via its use by LowerTEPass and consumers of such, which in turn are 
tested indirectly by virtue of everything passing into TIR via this choke 
point. Just wanted to test the water on whether folks on this PR have opinions 
here so I don't go off tilting at windmills.

##########
File path: src/relay/backend/te_compiler.h
##########
@@ -184,12 +206,15 @@ Target GetTargetFromInteger(DLDeviceType dev_type, 
TargetMap targets);
  * \param device_map An analysis result mapping each sub-expression to a 
device.
  * \return The lowered module, see above.
  */
-// TODO(@electriclilies): Not sure if this default initialization is correct...
 LoweredModule LowerTE(
     const IRModule& module, TargetMap targets, DeviceMap device_map,
     backend::StaticMemoryPlan memory_plan, const String& module_name,
     ProcessFn process_fn = [](Function f) {});
 
+using namespace transform;

Review comment:
       unqualified using

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap 
targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // copy main module funcs to unified funcs (what target do we need to 
annotate with here?)

Review comment:
       It might be less ceremony to just .Add directly into an IRModule result?

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -593,6 +593,14 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap 
device_context_map,
   return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {});
 }
 
+/*!
+ * \brief Obtain the Target from the device type.
+ * If homogenous compilation, this will return the only target.
+ * If heteregenous compilation, this will select associated using the targets_ 
Map.
+ *
+ * \param dev_type
+ * \return Target
+ */

Review comment:
       I'm all for more documentation but since this is declared and commented 
in the header the DNRY rule applies :-)

##########
File path: src/relay/backend/te_compiler.cc
##########
@@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap 
targets, DeviceMap devic
   return lowered_module;
 }
 
+IRModule LoweredModuleToIRModule(LoweredModule mod) {
+  Map<GlobalVar, BaseFunc> unified_funcs;
+  Map<GlobalTypeVar, TypeData> unified_type_defs;
+
+  // copy main module funcs to unified funcs (what target do we need to 
annotate with here?)
+  for (const auto& kv : mod.main_module->functions) {
+    const GlobalVar& var = kv.first;
+    const BaseFunc& func = kv.second;
+    ICHECK(!func->IsInstance<tir::PrimFuncNode>());
+    unified_funcs.Set(var, func);
+  }
+
+  // copy the type definitions for the main module
+  for (const auto& kv : mod.main_module->type_definitions) {
+    const GlobalTypeVar& ty_var = kv.first;
+    const TypeData& ty_data = kv.second;
+    unified_type_defs.Set(ty_var, ty_data);
+  }
+  // Move functions in per target IRModule into unified module
+  // Also move the type definitions
+  for (const auto& kv : mod.per_target_module) {
+    const String target = kv.first;
+    const IRModule target_module = kv.second;
+    // Move the per module functions, and annotate the funcs with their target
+    for (const auto& kv : target_module->functions) {
+      const GlobalVar& var = kv.first;
+      const BaseFunc& func = kv.second;
+      ICHECK(func->IsInstance<tir::PrimFuncNode>())
+          << "We expect the target_module to contain only PrimFuncs at this 
point, but got "
+          << func->GetTypeKey();
+      tir::PrimFunc primFunc = 
WithAttr(Downcast<tir::PrimFunc>(std::move(func)), attr::kTarget,
+                                        runtime::String(target));
+      unified_funcs.Set(var, primFunc);
+    }
+
+    // Move the type definitions for the per target IRModule

Review comment:
       Hmm, this part is not reversible, so what if we ICHECK that there are 
no typedefs in the lowered per-target modules? I think that should be the case?




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to