This is an automated email from the ASF dual-hosted git repository. zhic pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 9713d67 Created CSourceMetaData module for model metadata (#7002) 9713d67 is described below commit 9713d675c64ae3075e10be5acadeef1328a44bb5 Author: manupa-arm <manupa.karunara...@arm.com> AuthorDate: Mon Dec 21 22:07:33 2020 +0000 Created CSourceMetaData module for model metadata (#7002) * Created CSourceMetaData module for model metadata * Currently, there is a MetaData module to capture constants conditionally if the runtime modules implement const init PackedFuncs. However, this one relies on a load process in which the metadata is created on volatile memory that may not be usable in uTVM environments. * There is a need for model level metadata that is valid across all runtime modules such as the func registry when creating a system-lib. * This commit implements a CSourceMetaData module to hold func registry that collects function names from the runtime module and generates a c source file to be linked with the final artifact. 
* Modified and added export_library for utvm Change-Id: Ie2e8e2aea1a66520f03fe8af7cc5bdf27339ea10 * Created CSourceMetaData module for model metadata * fixed llvm_module to return null pfs for get_symbol and get_const_vars Change-Id: I84810e0695d4d6fb314af2469117f965eed71b51 * Created CSourceMetaData module for model metadata *fixed bundle_deploy tests Change-Id: I0d1332a4abbb6830531784c59264021bbbd7148a * Created CSourceMetaData module for model metadata *fixed export_library not to insert "options" when targeting tar *fixed unit tests Change-Id: Ia1686889498b71af66f1a0311a059154ad3c2c3e * Created CSourceMetaData module for model metadata * enable wasm to support csource metadata module * disabled non DSOExportables from using csource metadata module Change-Id: Ie09beaad35cbc2ef738d1d24d91e249b5e099569 * Created CSourceMetaData module for model metadata * changed const pfs to be called only on external modules or DSOExportable modules Change-Id: I6ad28f166c0fc27a2548c851bf9287ec805550d1 * Created CSourceMetaData module for model metadata * CSourceMetadata module wrapper is only created for c/llvm targets Change-Id: I13cb4140c17e2e1f91d495b15a1ff7eeab9fb14d * Created CSourceMetaData module for model metadata *target should be defined to use csourcemetadata module Change-Id: Id8e55b23d0007a79c550334de2c0fec63d40171f * Created CSourceMetaData module for model metadata * reinstate llvm func registry Change-Id: I53e0754b6fb533637f08b25e98064d8c04092de4 * Created CSourceMetaData module for model metadata * addressed comments and fixed bugs Change-Id: I26401685dc803aeaf7642c865df88d683419e859 * Created CSourceMetaData module for model metadata * addressed a missed comment Change-Id: I65e65c30bc780a946f3f1b8372c40a49a5c20582 * Created CSourceMetaData module for model metadata * te build interface should only include c-source metadata if targeting "c" Change-Id: Ie23cb8c6231c1f2de6d2827084774e3510288098 * Created CSourceMetaData module for model metadata * c_source modules 
should be created only if they are non-DSO exportable Change-Id: I53f2f8e9caa41f133446f8881b9dc541ebeee8cc * Created CSourceMetaData module for model metadata * documetation misalignment in source_module.cc Change-Id: I83e2c29b1f2980ca65a694304720dc58a5cb7879 * Created CSourceMetaData module for model metadata * typo : same object file written as a dependency in the Makefile Change-Id: I8becc4196d286cfb6372768687b3c836799dcb78 * Created CSourceMetaData module for model metadata * removed unused param from a brief Change-Id: Ie4db2aca3b7ea147bd8c65ef5d1cc2146f530e76 * Created CSourceMetaData module for model metadata * made export library use c as the format for c source modules Change-Id: Ie2fd6204414f0fa43988a8082d18af7a3225e237 * Created CSourceMetaData module for model metadata *addressed a nit Change-Id: I6084b8c06ddfaaece295439dbab589e6e202b664 --- apps/bundle_deploy/build_model.py | 2 - python/tvm/driver/build_module.py | 12 ++ python/tvm/micro/build.py | 14 +- python/tvm/runtime/module.py | 40 +++-- src/relay/backend/build_module.cc | 10 +- src/relay/backend/contrib/codegen_c/codegen.cc | 26 ++- src/relay/backend/contrib/dnnl/codegen.cc | 3 +- src/relay/backend/vm/compiler.cc | 6 +- src/target/func_registry_generator.cc | 2 +- src/target/func_registry_generator.h | 7 +- src/target/llvm/codegen_cpu.cc | 4 +- src/target/llvm/llvm_module.cc | 14 +- src/target/source/codegen_c_host.cc | 31 +--- src/target/source/codegen_c_host.h | 8 +- src/target/source/codegen_source_base.h | 22 ++- src/target/source/source_module.cc | 183 ++++++++++++++++++--- tests/micro/qemu/test_zephyr.py | 30 +++- tests/python/relay/test_pass_partition_graph.py | 28 ++-- tests/python/unittest/test_crt.py | 1 - tests/python/unittest/test_link_params.py | 2 +- .../python/unittest/test_runtime_module_export.py | 2 +- 21 files changed, 316 insertions(+), 131 deletions(-) diff --git a/apps/bundle_deploy/build_model.py b/apps/bundle_deploy/build_model.py index 623d246..a2513c8 100644 --- 
a/apps/bundle_deploy/build_model.py +++ b/apps/bundle_deploy/build_model.py @@ -51,7 +51,6 @@ def build_module(opts): build_dir = os.path.abspath(opts.out_dir) if not os.path.isdir(build_dir): os.makedirs(build_dir) - lib.save(os.path.join(build_dir, file_format_str.format(name="model", ext="o"))) with open( os.path.join(build_dir, file_format_str.format(name="graph", ext="json")), "w" @@ -85,7 +84,6 @@ def build_test_module(opts): build_dir = os.path.abspath(opts.out_dir) if not os.path.isdir(build_dir): os.makedirs(build_dir) - lib.save(os.path.join(build_dir, file_format_str.format(name="test_model", ext="o"))) with open( os.path.join(build_dir, file_format_str.format(name="test_graph", ext="json")), "w" diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index dc9d741..7ad48e1 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -424,4 +424,16 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi for mdev in device_modules: if mdev: rt_mod_host.import_module(mdev) + + if not isinstance(target_host, Target): + target_host = Target(target_host) + if ( + "system-lib" in target_host.attrs + and target_host.attrs["system-lib"].value == 1 + and target_host.kind.name == "c" + ): + create_csource_metadata_module = tvm._ffi.get_global_func( + "runtime.CreateCSourceMetadataModule" + ) + return create_csource_metadata_module([rt_mod_host], target_host) return rt_mod_host diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py index 4aec9ea..cad385b 100644 --- a/python/tvm/micro/build.py +++ b/python/tvm/micro/build.py @@ -95,6 +95,7 @@ _CRT_GENERATED_LIB_OPTIONS = copy.copy(_CRT_DEFAULT_OPTIONS) # void* arg0 = (((TVMValue*)args)[0].v_handle); # int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)]; _CRT_GENERATED_LIB_OPTIONS["cflags"].append("-Wno-unused-variable") +_CRT_GENERATED_LIB_OPTIONS["ccflags"].append("-Wno-unused-variable") # Many TVM-intrinsic operators 
(i.e. expf, in particular) @@ -159,9 +160,6 @@ def build_static_runtime( mod_build_dir = workspace.relpath(os.path.join("build", "module")) os.makedirs(mod_build_dir) mod_src_dir = workspace.relpath(os.path.join("src", "module")) - os.makedirs(mod_src_dir) - mod_src_path = os.path.join(mod_src_dir, "module.c") - module.save(mod_src_path, "cc") libs = [] for mod_or_src_dir in (extra_libs or []) + RUNTIME_LIB_SRC_DIRS: @@ -181,7 +179,15 @@ def build_static_runtime( libs.append(compiler.library(lib_build_dir, lib_srcs, lib_opts)) - libs.append(compiler.library(mod_build_dir, [mod_src_path], generated_lib_opts)) + mod_src_dir = workspace.relpath(os.path.join("src", "module")) + os.makedirs(mod_src_dir) + libs.append( + module.export_library( + mod_build_dir, + workspace_dir=mod_src_dir, + fcompile=lambda bdir, srcs, **kwargs: compiler.library(bdir, srcs, generated_lib_opts), + ) + ) runtime_build_dir = workspace.relpath(f"build/runtime") os.makedirs(runtime_build_dir) diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index cef6173..6326796 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name, unused-import, import-outside-toplevel +# pylint: disable=invalid-name, unused-import, import-outside-toplevel, inconsistent-return-statements """Runtime Module namespace.""" import os import ctypes @@ -252,7 +252,7 @@ class Module(object): def _dso_exportable(self): return self.type_key == "llvm" or self.type_key == "c" - def export_library(self, file_name, fcompile=None, addons=None, **kwargs): + def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs): """Export the module and its imported device code one library. This function only works on host llvm modules. 
@@ -268,8 +268,19 @@ class Module(object): If fcompile has attribute object_format, will compile host library to that format. Otherwise, will use default format "o". + workspace_dir : str, optional + the path to a directory used to create intermediary + artifacts for the process exporting of the library. + If this is not provided a temporary dir will be created. + kwargs : dict, optional Additional arguments passed to fcompile + + Returns + ------- + result of fcompile() : unknown, optional + If the compilation function returns an artifact it would be returned via + export_library, if any. """ # NOTE: this function depends on contrib library features # which are only available in when TVM function is available. @@ -292,22 +303,28 @@ class Module(object): return modules = self._collect_dso_modules() - temp = _utils.tempdir() + if workspace_dir is None: + temp = _utils.tempdir() + workspace_dir = temp.temp_dir files = addons if addons else [] is_system_lib = False has_c_module = False llvm_target_triple = None for index, module in enumerate(modules): if fcompile is not None and hasattr(fcompile, "object_format"): - object_format = fcompile.object_format + if module.type_key == "c": + object_format = "c" + has_c_module = True + else: + object_format = fcompile.object_format else: if module.type_key == "llvm": object_format = "o" else: assert module.type_key == "c" - object_format = "cc" + object_format = "c" has_c_module = True - path_obj = temp.relpath("lib" + str(index) + "." + object_format) + path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}") module.save(path_obj) files.append(path_obj) is_system_lib = ( @@ -330,17 +347,20 @@ class Module(object): if self.imported_modules: if enabled("llvm") and llvm_target_triple: - path_obj = temp.relpath("devc." 
+ object_format) + path_obj = os.path.join(workspace_dir, f"devc.{object_format}") m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_triple) m.save(path_obj) files.append(path_obj) else: - path_cc = temp.relpath("devc.cc") + path_cc = os.path.join(workspace_dir, "devc.c") with open(path_cc, "w") as f: f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib)) files.append(path_cc) - if has_c_module: + # The imports could contain a c module but the object format could be tar + # Thus, it would not recognize the following include paths as options + # which are there assuming a c compiler is the fcompile. + if has_c_module and not file_name.endswith(".tar"): options = [] if "options" in kwargs: opts = kwargs["options"] @@ -348,7 +368,7 @@ class Module(object): opts = options + ["-I" + path for path in find_include_path()] kwargs.update({"options": opts}) - fcompile(file_name, files, **kwargs) + return fcompile(file_name, files, **kwargs) def system_lib(): diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index a0828d1..09b0966 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -510,18 +510,14 @@ class RelayBuildModule : public runtime::ModuleNode { // If we cannot decide the target is LLVM, we create an empty CSourceModule. // The code content is initialized with ";" to prevent complaining // from CSourceModuleNode::SaveToFile. - ret_.mod = tvm::codegen::CSourceModuleCreate(";", ""); + ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{}); } } else { ret_.mod = tvm::build(lowered_funcs, target_host_); } - Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules(); - // TODO(zhiics) We should be able to completely switch to MetadataModule no - // matter whether there are external modules or not. 
- if (!ext_mods.empty()) { - ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods); - } + auto ext_mods = graph_codegen_->GetExternalModules(); + ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost()); } private: diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index 935ac16..998393d 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -215,20 +215,19 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code class CSourceCodegen : public CSourceModuleCodegenBase { public: - std::pair<std::string, Array<String>> GenCFunc(const Function& func) { + std::tuple<Array<String>, String, String> GenCFunc(const Function& func) { ICHECK(func.defined()) << "Input error: expect a Relay function."; - - // Record the external symbol for runtime lookup. - auto sid = GetExtSymbol(func); - - CodegenC builder(sid); + CodegenC builder(GetExtSymbol(func)); auto out = builder.VisitExpr(func->body); - code_stream_ << builder.JIT(out); - - return {sid, builder.const_vars_}; + return std::make_tuple(builder.const_vars_, builder.ext_func_id_, builder.JIT(out)); } runtime::Module CreateCSourceModule(const ObjectRef& ref) override { + ICHECK(ref->IsInstance<FunctionNode>()); + auto res = GenCFunc(Downcast<Function>(ref)); + Array<String> variables = std::get<0>(res); + String func_name = std::get<1>(res); + // Create headers code_stream_ << "#include <cstring>\n"; code_stream_ << "#include <vector>\n"; @@ -259,18 +258,13 @@ class CSourceCodegen : public CSourceModuleCodegenBase { )op_macro"; code_stream_ << operator_macro << "\n\n"; - - ICHECK(ref->IsInstance<FunctionNode>()); - auto res = GenCFunc(Downcast<Function>(ref)); + code_stream_ << std::get<2>(res); std::string code = code_stream_.str(); - String sym = std::get<0>(res); - Array<String> variables = std::get<1>(res); - // 
Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code, "c", sym, variables); + return (*pf)(code, "c", Array<String>{func_name}, variables); } private: diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index bfc5c77..c9a5828 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -413,7 +413,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase { // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code, "c", sym, variables); + // TODO(@manupa-arm): pass the function names to enable system-lib creation + return (*pf)(code, "c", Array<String>{sym}, variables); } private: diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index bed2510..8fbe31e 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1146,11 +1146,9 @@ void VMCompiler::Codegen() { } else { // There is no function handled by TVM. We create a virtual main module // to make sure a DSO module will be also available. 
- exec_->lib = codegen::CSourceModuleCreate(";", ""); - } - if (!ext_mods.empty()) { - exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods); + exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{}); } + exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_); } ExprDeviceMap VMCompiler::AnalyzeContext() const { diff --git a/src/target/func_registry_generator.cc b/src/target/func_registry_generator.cc index 402d0f8..7c948d5 100644 --- a/src/target/func_registry_generator.cc +++ b/src/target/func_registry_generator.cc @@ -29,7 +29,7 @@ namespace tvm { namespace target { -std::string GenerateFuncRegistryNames(const std::vector<std::string>& function_names) { +std::string GenerateFuncRegistryNames(const Array<String>& function_names) { std::stringstream ss; ss << (unsigned char)(function_names.size()); for (auto f : function_names) { diff --git a/src/target/func_registry_generator.h b/src/target/func_registry_generator.h index 362fca8..fb59648 100644 --- a/src/target/func_registry_generator.h +++ b/src/target/func_registry_generator.h @@ -24,13 +24,18 @@ #ifndef TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ #define TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_ +#include <tvm/runtime/container.h> + #include <string> #include <vector> +using tvm::runtime::Array; +using tvm::runtime::String; + namespace tvm { namespace target { -std::string GenerateFuncRegistryNames(const std::vector<std::string>& function_names); +std::string GenerateFuncRegistryNames(const Array<String>& function_names); } // namespace target } // namespace tvm diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index fea5f80..6143e70 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -794,10 +794,10 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() { void CodeGenCPU::AddStartupFunction() { if (registry_functions_.size() != 0) { ICHECK(is_system_lib_) << "Loading of --system-lib modules is 
yet to be defined for C runtime"; - std::vector<std::string> symbols; + Array<String> symbols; std::vector<llvm::Constant*> funcs; for (auto sym : registry_functions_) { - symbols.emplace_back(sym.first); + symbols.push_back(sym.first); funcs.emplace_back(llvm::ConstantExpr::getBitCast( sym.second, ftype_tvm_backend_packed_c_func_->getPointerTo())); } diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index a0ab49d..43d2097 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -60,6 +60,13 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (name == "__tvm_is_system_module") { bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr); return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; }); + } else if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; }); + } else if (name == "get_symbol") { + return PackedFunc(nullptr); + } else if (name == "get_const_vars") { + return PackedFunc(nullptr); } else if (name == "_get_target_triple") { std::string target_triple = tm_->getTargetTriple().str(); // getTargetTriple() doesn't include other flags besides the triple. 
Add back flags which are @@ -218,9 +225,10 @@ class LLVMModuleNode final : public runtime::ModuleNode { ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey(); auto f = Downcast<PrimFunc>(kv.second); + auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); + ICHECK(global_symbol.defined()); + function_names_.push_back(global_symbol.value()); if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); entry_func = global_symbol.value(); } funcs.push_back(f); @@ -377,6 +385,8 @@ class LLVMModuleNode final : public runtime::ModuleNode { std::unique_ptr<llvm::Module> module_; // the context. std::shared_ptr<llvm::LLVMContext> ctx_; + /* \brief names of the functions declared in this module */ + Array<String> function_names_; }; TVM_REGISTER_GLOBAL("target.build.llvm") diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 0a19fc1..bee5441 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -55,7 +55,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) { auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; - function_names_.emplace_back(global_symbol.value()); + function_names_.push_back(global_symbol.value()); CodeGenC::AddFunction(f); } @@ -73,7 +73,7 @@ void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) { << " out_ret_tcode[0] = " << kTVMNullptr << ";\n" << " return 0;\n"; - function_names_.emplace_back(tvm::runtime::symbol::tvm_lookup_linked_param); + function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param); for (auto kv : params) { decl_stream << "\n" << "#ifdef __cplusplus\n" @@ -322,29 +322,6 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const 
char* compare, << "? (" << a_id << ") : (" << b_id << "))"; } -void CodeGenCHost::GenerateFuncRegistry() { - decl_stream << "#include <tvm/runtime/crt/module.h>\n"; - stream << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n"; - for (auto f : function_names_) { - stream << " (TVMBackendPackedCFunc)" << f << ",\n"; - } - stream << "};\n"; - auto registry = target::GenerateFuncRegistryNames(function_names_); - stream << "static const TVMFuncRegistry _tvm_func_registry = {\n" - << " \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\"," - << " _tvm_func_array,\n" - << "};\n"; -} - -void CodeGenCHost::GenerateCrtSystemLib() { - stream << "static const TVMModule _tvm_system_lib = {\n" - << " &_tvm_func_registry,\n" - << "};\n" - << "const TVMModule* TVMSystemLibEntryPoint(void) {\n" - << " return &_tvm_system_lib;\n" - << "}\n"; -} - runtime::Module BuildCHost(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; @@ -380,12 +357,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) { if (target->GetAttr<Bool>("system-lib").value_or(Bool(false))) { ICHECK_EQ(target->GetAttr<String>("runtime").value_or(""), "c") << "c target only supports generating C runtime SystemLibs"; - cg.GenerateFuncRegistry(); - cg.GenerateCrtSystemLib(); } std::string code = cg.Finish(); - return CSourceModuleCreate(code, "c"); + return CSourceModuleCreate(code, "c", cg.GetFunctionNames()); } TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost); diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index b54b6fb..97fe7ab 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -59,18 +59,14 @@ class CodeGenCHost final : public CodeGenC { void VisitStmt_(const AssertStmtNode* op) final; // NOLINT(*) - /*! \brief Generate C runtime FuncRegistry global constant. */ - void GenerateFuncRegistry(); - - /*! \brief Generate C runtime SystemLib entry point. 
*/ - void GenerateCrtSystemLib(); + Array<String> GetFunctionNames() { return function_names_; } private: std::string module_name_; /* \brief tracks declared global variables which live despite GetUniqueName */ std::set<std::string> declared_globals_; /* \brief names of the functions declared in this module */ - std::vector<std::string> function_names_; + Array<String> function_names_; /*! \brief whether to emit asserts in the resulting C code */ bool emit_asserts_; diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h index 7e5e403..ed838f8 100644 --- a/src/target/source/codegen_source_base.h +++ b/src/target/source/codegen_source_base.h @@ -136,25 +136,26 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt); * \brief Create a C source module for viewing and compiling GCC code. * \param code The code to be viewed. * \param fmt The code format. - * \param symbol The symbol that the c source module represents. + * \param func_names The name of functions inside the runtime module. * \param const_vars. The constant variables that the c source module needs. * \return The created module. */ runtime::Module CSourceModuleCreate(const String& code, const String& fmt, - const String& symbol = "", + const Array<String>& func_names, const Array<String>& const_vars = {}); /*! * \brief Wrap the submodules in a metadata module. * \param params The variable to constant mapping that is collected by the host * module. - * \param dso_module The host module to be wrapped. - * \param modules The modules to be wrapped. + * \param target_module The main TIR-lowered internal runtime module + * \param modules All the external modules that needs to be imported inside the metadata module(s). + * \param target The target that all the modules are compiled for * \return The wrapped module. 
*/ runtime::Module CreateMetadataModule( - const std::unordered_map<std::string, runtime::NDArray>& params, - const runtime::Module& dso_module, const Array<runtime::Module>& modules); + const std::unordered_map<std::string, runtime::NDArray>& params, runtime::Module target_module, + const Array<runtime::Module>& ext_modules, Target target); /*! * \brief Create a source module for viewing and limited saving for device. @@ -167,6 +168,15 @@ runtime::Module CreateMetadataModule( runtime::Module DeviceSourceModuleCreate( std::string data, std::string fmt, std::unordered_map<std::string, runtime::FunctionInfo> fmap, std::string type_key, std::function<std::string(const std::string&)> fget_source = nullptr); + +/*! + * \brief Wrap the submodules that are to be wrapped in a c-source metadata module. + * \param modules The modules to be wrapped. + * \param target the target the modules are compiled for. + * \return The wrapped module. + */ +runtime::Module CreateCSourceMetadataModule(const Array<runtime::Module>& modules, Target target); + } // namespace codegen } // namespace tvm #endif // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_ diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 3be658a..4b4770a 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -27,6 +27,8 @@ #include "../../runtime/file_utils.h" #include "../../runtime/meta_data.h" +#include "../../support/str_escape.h" +#include "../func_registry_generator.h" #include "codegen_source_base.h" namespace tvm { @@ -46,40 +48,66 @@ using runtime::SaveBinaryToFile; * codegens, such as graph runtime codegen and the vm compiler. * * \param params The metadata for initialization of all modules. - * \param dso_module The DSO module that contains TVM primitives. - * \param modules The submodules that will be wrapped, e.g. CSource modules that - * contain vendor library calls or customized runtime modules. 
- * + * \param target_module the internal module that is compiled by tvm. + * \param ext_modules The external modules that needs to be imported inside the metadata + * module(s). + * \param target The target that all the modules are compiled for * \return The created metadata module that manages initialization of metadata. */ runtime::Module CreateMetadataModule( const std::unordered_map<std::string, runtime::NDArray>& params, - const runtime::Module& dso_module, const Array<runtime::Module>& modules) { + tvm::runtime::Module target_module, const Array<runtime::Module>& ext_modules, Target target) { + Array<tvm::runtime::Module> csource_modules; + Array<tvm::runtime::Module> binary_modules; + + auto DSOExportable = [](tvm::runtime::Module& mod) { + return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c"); + }; + // Wrap all submodules in the initialization wrapper. std::unordered_map<std::string, std::vector<std::string>> sym_metadata; - for (runtime::Module it : modules) { - auto pf_sym = it.GetFunction("get_symbol"); - auto pf_var = it.GetFunction("get_const_vars"); + for (tvm::runtime::Module mod : ext_modules) { + auto pf_sym = mod.GetFunction("get_symbol"); + auto pf_var = mod.GetFunction("get_const_vars"); + std::vector<std::string> arrays; if (pf_sym != nullptr && pf_var != nullptr) { String symbol = pf_sym(); Array<String> variables = pf_var(); - std::vector<std::string> arrays; for (size_t i = 0; i < variables.size(); i++) { arrays.push_back(variables[i].operator std::string()); } ICHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol; sym_metadata[symbol] = arrays; } + // We only need loading of serialized constant data + // if there are constants present and required by the + // runtime module to be initialized by the binary + // metadata module. If not rest of the modules are + // wrapped in c-source metadata module. 
+ + // TODO(@manupa-arm) : we should be able to use csource_metadata + // if the variables are empty when all the runtime modules implement get_func_names + if (arrays.empty() && DSOExportable(mod) && target->kind->name == "c") { + csource_modules.push_back(mod); + } else { + binary_modules.push_back(mod); + } } - // Wrap the modules. - runtime::Module init_m = runtime::MetadataModuleCreate(params, sym_metadata); - init_m.Import(dso_module); - for (const auto& it : modules) { - init_m.Import(it); + if (target.defined() && target->kind->name == "c") { + csource_modules.push_back(target_module); + target_module = CreateCSourceMetadataModule(csource_modules, target); } - return init_m; + if (!binary_modules.empty()) { + runtime::Module binary_meta_mod = runtime::MetadataModuleCreate(params, sym_metadata); + binary_meta_mod.Import(target_module); + for (const auto& it : binary_modules) { + binary_meta_mod.Import(it); + } + return binary_meta_mod; + } + return target_module; } // Simulator function @@ -109,18 +137,25 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) { // Simulator function class CSourceModuleNode : public runtime::ModuleNode { public: - CSourceModuleNode(const std::string& code, const std::string& fmt, const std::string& symbol, - const Array<String>& const_vars) - : code_(code), fmt_(fmt), symbol_(symbol), const_vars_(const_vars) {} + CSourceModuleNode(const std::string& code, const std::string& fmt, + const Array<String>& func_names, const Array<String>& const_vars) + : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {} const char* type_key() const { return "c"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { + // Currently c-source module is used as demonstration purposes with binary metadata module + // that expects get_symbol interface. When c-source module is used as external module, it + // will only contain one function. 
However, when its used as an internal module (e.g., target + // "c") it can have many functions. if (name == "get_symbol") { return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_; }); + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; }); } else if (name == "get_const_vars") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; }); + } else if (name == "get_func_names") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; }); } else { return PackedFunc(nullptr); } @@ -131,7 +166,7 @@ class CSourceModuleNode : public runtime::ModuleNode { void SaveToFile(const std::string& file_name, const std::string& format) final { std::string fmt = GetFileFormat(file_name, format); std::string meta_file = GetMetaFilePath(file_name); - if (fmt == "cc") { + if (fmt == "c") { ICHECK_NE(code_.length(), 0); SaveBinaryToFile(file_name, code_); } else { @@ -142,17 +177,109 @@ class CSourceModuleNode : public runtime::ModuleNode { protected: std::string code_; std::string fmt_; - std::string symbol_; Array<String> const_vars_; + Array<String> func_names_; }; -runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const String& symbol, +runtime::Module CSourceModuleCreate(const String& code, const String& fmt, + const Array<String>& func_names, const Array<String>& const_vars) { auto n = make_object<CSourceModuleNode>(code.operator std::string(), fmt.operator std::string(), - symbol.operator std::string(), const_vars); + func_names, const_vars); return runtime::Module(n); } +class CSourceMetadataModuleNode : public runtime::ModuleNode { + public: + CSourceMetadataModuleNode(const Array<String>& func_names, const std::string& fmt, Target target) + : fmt_(fmt), func_names_(func_names), target_(target) { + CreateSource(); + } + const char* type_key() const { return "c"; } + + std::string GetSource(const 
std::string& format) final { return code_.str(); } + + PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final { + return PackedFunc(nullptr); + } + + void SaveToFile(const std::string& file_name, const std::string& format) final { + std::string fmt = GetFileFormat(file_name, format); + std::string meta_file = GetMetaFilePath(file_name); + if (fmt == "c") { + auto code_str = code_.str(); + ICHECK_NE(code_str.length(), 0); + SaveBinaryToFile(file_name, code_str); + } else { + ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_; + } + } + + protected: + std::stringstream code_; + std::string fmt_; + Array<String> func_names_; + Target target_; + + void CreateFuncRegistry() { + code_ << "#include <tvm/runtime/crt/module.h>\n"; + for (const auto& fname : func_names_) { + code_ << "#ifdef __cplusplus\n"; + code_ << "extern \"C\"\n"; + code_ << "#endif\n"; + code_ << "TVM_DLL int32_t " << fname.data(); + code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, int* " + "out_type_code);\n"; + } + code_ << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n"; + for (auto f : func_names_) { + code_ << " (TVMBackendPackedCFunc)" << f << ",\n"; + } + code_ << "};\n"; + auto registry = target::GenerateFuncRegistryNames(func_names_); + code_ << "static const TVMFuncRegistry _tvm_func_registry = {\n" + << " \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\"," + << " _tvm_func_array,\n" + << "};\n"; + } + + void GenerateCrtSystemLib() { + code_ << "static const TVMModule _tvm_system_lib = {\n" + << " &_tvm_func_registry,\n" + << "};\n" + << "const TVMModule* TVMSystemLibEntryPoint(void) {\n" + << " return &_tvm_system_lib;\n" + << "}\n"; + } + + void CreateSource() { + if (target_->GetAttr<Bool>("system-lib").value_or(Bool(false)) && !func_names_.empty()) { + CreateFuncRegistry(); + GenerateCrtSystemLib(); + } + code_ << ";"; + } +}; + +runtime::Module CreateCSourceMetadataModule(const 
Array<runtime::Module>& modules, Target target) { + Array<String> func_names; + for (runtime::Module mod : modules) { + auto pf_funcs = mod.GetFunction("get_func_names"); + if (pf_funcs != nullptr) { + Array<String> func_names_ = pf_funcs(); + for (const auto& fname : func_names_) { + func_names.push_back(fname); + } + } + } + auto n = make_object<CSourceMetadataModuleNode>(func_names, "cc", target); + auto csrc_metadata_module = runtime::Module(n); + for (const auto& mod : modules) { + csrc_metadata_module.Import(mod); + } + return std::move(csrc_metadata_module); +} + // supports limited save without cross compile class DeviceSourceModuleNode final : public runtime::ModuleNode { public: @@ -209,8 +336,14 @@ runtime::Module DeviceSourceModuleCreate( TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate); TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate") - .set_body_typed([](String code, String fmt, String symbol, Array<String> const_vars) { - return CSourceModuleCreate(code, fmt, symbol, const_vars); + .set_body_typed([](String code, String fmt, Array<String> func_names, + Array<String> const_vars) { + return CSourceModuleCreate(code, fmt, func_names, const_vars); + }); + +TVM_REGISTER_GLOBAL("runtime.CreateCSourceMetadataModule") + .set_body_typed([](const Array<runtime::Module>& modules, Target target) { + return CreateCSourceMetadataModule(modules, target); }); } // namespace codegen diff --git a/tests/micro/qemu/test_zephyr.py b/tests/micro/qemu/test_zephyr.py index 2213203..3e73307 100644 --- a/tests/micro/qemu/test_zephyr.py +++ b/tests/micro/qemu/test_zephyr.py @@ -29,7 +29,7 @@ import numpy as np import tvm import tvm.rpc import tvm.micro -import tvm.relay +import tvm.relay as relay from tvm.micro.contrib import zephyr from tvm.contrib import utils @@ -143,5 +143,33 @@ def test_compile_runtime(platform): test_basic_add(sess) +def test_relay(platform): + """Testing a simple relay graph""" + model, zephyr_board = 
PLATFORMS[platform] + shape = (10,) + dtype = "int8" + + # Construct Relay program. + x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) + xx = relay.multiply(x, x) + z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype))) + func = relay.Function([x], z) + + target = tvm.target.target.micro(model) + with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}): + graph, mod, params = tvm.relay.build(func, target=target) + + with _make_session(model, target, zephyr_board, mod) as session: + graph_mod = tvm.micro.create_local_graph_runtime( + graph, session.get_system_lib(), session.context + ) + graph_mod.set_input(**params) + x_in = np.random.randint(10, size=shape[0], dtype=dtype) + graph_mod.run(x=x_in) + result = graph_mod.get_output(0).asnumpy() + tvm.testing.assert_allclose(graph_mod.get_input(0).asnumpy(), x_in) + tvm.testing.assert_allclose(result, x_in * x_in + 1) + + if __name__ == "__main__": sys.exit(pytest.main([os.path.dirname(__file__)] + sys.argv[1:])) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 059d0b4..d8f674e 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -273,19 +273,23 @@ def test_multi_node_compiler(): map_inputs = {"w{}".format(i): w_data[i] for i in range(8)} map_inputs["x"] = x_data - check_result( - mod, - map_inputs, - (30, 10), - np.concatenate( - ( - ((x_data + w_data[0]) - w_data[1]) * w_data[2], - ((x_data + w_data[3]) - w_data[4]) * w_data[5], - x_data + w_data[6] - w_data[7], + + targets = ["llvm", "c -runtime=c --system-lib"] + for tgt in targets: + check_result( + mod, + map_inputs, + (30, 10), + np.concatenate( + ( + ((x_data + w_data[0]) - w_data[1]) * w_data[2], + ((x_data + w_data[3]) - w_data[4]) * w_data[5], + x_data + w_data[6] - w_data[7], + ), + axis=0, ), - axis=0, - ), - ) + target=tgt, + ) def test_extern_ccompiler_single_op(): diff 
--git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py index 07a4cfc..1d84d4e 100644 --- a/tests/python/unittest/test_crt.py +++ b/tests/python/unittest/test_crt.py @@ -49,7 +49,6 @@ def _make_sess_from_op(workspace, op_name, sched, arg_bufs): def _make_session(workspace, mod): compiler = tvm.micro.DefaultCompiler(target=TARGET) opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR, "host")) - micro_binary = tvm.micro.build_static_runtime( # the x86 compiler *expects* you to give the exact same dictionary for both # lib_opts and bin_opts. so the library compiler is mutating lib_opts and diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py index e3bd634..da87a31 100644 --- a/tests/python/unittest/test_link_params.py +++ b/tests/python/unittest/test_link_params.py @@ -266,7 +266,7 @@ def test_c_link_params(): lib = tvm.relay.build(mod, target, params=param_init) assert set(lib.params.keys()) == {"p0", "p1"} # NOTE: op folded - src = lib.lib.get_source() + src = lib.lib.imported_modules[0].get_source() lib.lib.save("test.c", "cc") c_dtype = _get_c_datatype(dtype) src_lines = src.split("\n") diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index 88b7af9..af9a8ab 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -58,7 +58,7 @@ def generate_engine_module(): import tvm.runtime._ffi_api gen_engine_header() - csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", None) + csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", [], None) return csource_module