mbs-octoml commented on a change in pull request #8802: URL: https://github.com/apache/tvm/pull/8802#discussion_r692530723
########## File path: src/relay/backend/aot_executor_codegen.cc ########## @@ -48,6 +48,7 @@ namespace backend { using IntegerArray = Array<Integer>; using StorageMap = std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>; +using namespace tec; Review comment: (stylistic question) It looks like this is an established pattern in the code but the style guide discourages this: https://google.github.io/styleguide/cppguide.html#Namespaces ########## File path: python/tvm/relay/backend/graph_executor_codegen.py ########## @@ -53,7 +53,7 @@ def __init__(self, mod, target): self._get_irmodule = self._mod["get_irmodule"] self._setup(mod, target) - def _setup(self, mod, target): + def _setup(self, mod, target: Dict[int, Target]): Review comment: This type is stronger than the dynamic check below; perhaps you put it in while you were grokking the code? ########## File path: src/relay/backend/graph_executor_codegen.cc ########## @@ -55,6 +55,7 @@ using GraphAttrs = std::unordered_map<std::string, dmlc::any>; using GraphObjectPtr = std::shared_ptr<GraphNode>; using GraphInputObjectPtr = std::shared_ptr<GraphInputNode>; using GraphOpObjectPtr = std::shared_ptr<GraphOpNode>; +using namespace tec; Review comment: ditto ########## File path: src/relay/backend/te_compiler.cc ########## @@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic return lowered_module; } +IRModule LoweredModuleToIRModule(LoweredModule mod) { + Map<GlobalVar, BaseFunc> unified_funcs; + Map<GlobalTypeVar, TypeData> unified_type_defs; + + // copy main module funcs to unified funcs (what target do we need to annotate with here?) 
+ for (const auto& kv : mod.main_module->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + ICHECK(!func->IsInstance<tir::PrimFuncNode>()); + unified_funcs.Set(var, func); + } + + // copy the type definitions for the main module + for (const auto& kv : mod.main_module->type_definitions) { + const GlobalTypeVar& ty_var = kv.first; + const TypeData& ty_data = kv.second; + unified_type_defs.Set(ty_var, ty_data); + } + // Move functions in per target IRModule into unified module + // Also move the type definitions + for (const auto& kv : mod.per_target_module) { + const String target = kv.first; + const IRModule target_module = kv.second; + // Move the per module functions, and annotate the funcs with their target + for (const auto& kv : target_module->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + ICHECK(func->IsInstance<tir::PrimFuncNode>()) + << "We expect the target_module to contain only PrimFuncs at this point, but got " + << func->GetTypeKey(); + tir::PrimFunc primFunc = WithAttr(Downcast<tir::PrimFunc>(std::move(func)), attr::kTarget, + runtime::String(target)); + unified_funcs.Set(var, primFunc); + } + + // Move the type definitions for the per target IRModule + for (const auto& kv : target_module->type_definitions) { + const GlobalTypeVar& ty_var = kv.first; + const TypeData& ty_data = kv.second; + unified_type_defs.Set(ty_var, ty_data); + } + } + + IRModule ret_mod = + WithAttr(IRModule(unified_funcs, unified_type_defs), "external_mods", mod.external_mods); + ret_mod = WithAttr(ret_mod, "main_func_info", mod.main_func_info); + return ret_mod; +} + +LoweredModule IRModuleToLoweredModule(IRModule mod) { + Map<GlobalVar, BaseFunc> main_mod_funcs; + Map<String, Map<GlobalVar, BaseFunc>> target_funcs; + for (const auto& kv : mod->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + if (func->IsInstance<relay::FunctionNode>()) { + main_mod_funcs.Set(var, 
func); + } else if (func->IsInstance<tir::PrimFuncNode>()) { + // Extract target + auto target = func->GetAttr<String>(attr::kTarget); + ICHECK(!target) << "Target should be set at this point"; + + // Put the function in target_funcs + if (!target_funcs.count(target.value())) { + // Initialize the map and put it in target_funcs + Map<GlobalVar, BaseFunc> funcs; + funcs.Set(var, func); + target_funcs.Set(target.value(), funcs); + + } else { + // The map is initialized, so just add the function. + Map<GlobalVar, BaseFunc> funcs = target_funcs.at(target.value()); + funcs.Set(var, func); + } + } else { + LOG(FATAL) + << "The function types in the IRModule should be RelayFunction or PrimFunc, but got " + << func->GetTypeKey(); + } + } + // Create the per_target_module map + Map<String, IRModule> per_target_modules; + for (const auto& kv : target_funcs) { + String target = kv.first; + Map<GlobalVar, BaseFunc> funcs = kv.second; + // Here, we just copy the type defs to every module. Since TIR doesn't use the type defs, + // this duplication should be OK. 
+ per_target_modules.Set(target, IRModule(funcs, mod->type_definitions)); + } + LoweredModule lowered_module; + lowered_module.main_module = IRModule(main_mod_funcs, mod->type_definitions); + lowered_module.per_target_module = per_target_modules; + + // Extract external modules and main func info, add to lowered module if they exist + auto external_mods = mod->GetAttr<Array<tvm::runtime::Module>>("external_mods"); + if (external_mods) { + lowered_module.external_mods = external_mods.value(); + } + auto main_func_info = mod->GetAttr<backend::FunctionInfo>("main_func_info"); + if (main_func_info) { + lowered_module.main_func_info = main_func_info.value(); + } + return lowered_module; +} + +Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map, + backend::StaticMemoryPlan memory_plan, const String& module_name, + std::function<void(Function)> process_fn) { + runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module, + PassContext ctx) { + return LoweredModuleToIRModule( + LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn)); + }; + return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {}); +} Review comment: [A question to the community and not specific to your PR Lily!] This is a good example of code which could be easily unit tested in C++ in the, er, 'conventional' sense. That is, as a reader I could expect to go to tests/cpp/relay/backend/te_compiler_test.cc and look for TEST(IRModuleToLoweredModule, ...). Currently this new code is tested indirectly via its use by LowerTEPass and consumers of such, which in turn are tested indirectly by virtue of everything passing into TIR via this choke point. Just wanted to test the water on whether folks on this PR have opinions here so I don't go off tilting at windmills. 
########## File path: src/relay/backend/te_compiler.h ########## @@ -184,12 +206,15 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets); * \param device_map An analysis result mapping each sub-expression to a device. * \return The lowered module, see above. */ -// TODO(@electriclilies): Not sure if this default initialization is correct... LoweredModule LowerTE( const IRModule& module, TargetMap targets, DeviceMap device_map, backend::StaticMemoryPlan memory_plan, const String& module_name, ProcessFn process_fn = [](Function f) {}); +using namespace transform; Review comment: unqualified using ########## File path: src/relay/backend/te_compiler.cc ########## @@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic return lowered_module; } +IRModule LoweredModuleToIRModule(LoweredModule mod) { + Map<GlobalVar, BaseFunc> unified_funcs; + Map<GlobalTypeVar, TypeData> unified_type_defs; + + // copy main module funcs to unified funcs (what target do we need to annotate with here?) Review comment: It might be less ceremony to just .Add directly into an IRModule result? ########## File path: src/relay/backend/te_compiler.cc ########## @@ -593,6 +593,14 @@ Pass LowerTensorExpr(TargetMap targets, DeviceMap device_context_map, return CreateFunctionPass(pass_func, 0, "LowerTensorExpr", {}); } +/*! + * \brief Obtain the Target from the device type. + * If homogenous compilation, this will return the only target. + * If heteregenous compilation, this will select associated using the targets_ Map. 
+ * + * \param dev_type + * \return Target + */ Review comment: I'm all for more documentation but since this is declared and commented in the header the DNRY rule applies :-) ########## File path: src/relay/backend/te_compiler.cc ########## @@ -875,6 +878,121 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic return lowered_module; } +IRModule LoweredModuleToIRModule(LoweredModule mod) { + Map<GlobalVar, BaseFunc> unified_funcs; + Map<GlobalTypeVar, TypeData> unified_type_defs; + + // copy main module funcs to unified funcs (what target do we need to annotate with here?) + for (const auto& kv : mod.main_module->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + ICHECK(!func->IsInstance<tir::PrimFuncNode>()); + unified_funcs.Set(var, func); + } + + // copy the type definitions for the main module + for (const auto& kv : mod.main_module->type_definitions) { + const GlobalTypeVar& ty_var = kv.first; + const TypeData& ty_data = kv.second; + unified_type_defs.Set(ty_var, ty_data); + } + // Move functions in per target IRModule into unified module + // Also move the type definitions + for (const auto& kv : mod.per_target_module) { + const String target = kv.first; + const IRModule target_module = kv.second; + // Move the per module functions, and annotate the funcs with their target + for (const auto& kv : target_module->functions) { + const GlobalVar& var = kv.first; + const BaseFunc& func = kv.second; + ICHECK(func->IsInstance<tir::PrimFuncNode>()) + << "We expect the target_module to contain only PrimFuncs at this point, but got " + << func->GetTypeKey(); + tir::PrimFunc primFunc = WithAttr(Downcast<tir::PrimFunc>(std::move(func)), attr::kTarget, + runtime::String(target)); + unified_funcs.Set(var, primFunc); + } + + // Move the type definitions for the per target IRModule Review comment: Hmm, this part is not reversible, so what about if we ICHECK there are no typedefs in the lowered. 
I think that should be the case? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org