This is an automated email from the ASF dual-hosted git repository.

zhic pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9713d67  Created CSourceMetaData module for model metadata (#7002)
9713d67 is described below

commit 9713d675c64ae3075e10be5acadeef1328a44bb5
Author: manupa-arm <manupa.karunara...@arm.com>
AuthorDate: Mon Dec 21 22:07:33 2020 +0000

    Created CSourceMetaData module for model metadata (#7002)
    
    * Created CSourceMetaData module for model metadata
    
    * Currently, there is a MetaData module to capture constants
      conditionally if the runtime modules implement const init
      PackedFuncs. However, it relies on a load process in which
      the metadata is created in volatile memory that may not be
      usable in uTVM environments.
    * There is a need for model-level metadata that is valid
      across all runtime modules, such as the func registry
      used when creating a system-lib.
    * This commit implements a CSourceMetaData module to hold
      the func registry: it collects function names from the
      runtime modules and generates a C source file to be
      linked with the final artifact.
    * Modified export_library and used it for uTVM builds
    
    Change-Id: Ie2e8e2aea1a66520f03fe8af7cc5bdf27339ea10
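
For illustration, the Python sketch below shows roughly what C source a func registry of this kind contains. It is modeled on GenerateFuncRegistryNames and on the CodeGenCHost::GenerateFuncRegistry / GenerateCrtSystemLib code that this commit removes from codegen_c_host.cc. It assumes the name blob is a one-byte function count followed by NUL-terminated names (the loop body of GenerateFuncRegistryNames is cut off in this diff). The helper names are hypothetical; this is not TVM's implementation.

```python
from typing import List


def func_registry_names(function_names: List[str]) -> bytes:
    """Assumed layout: 1-byte count, then each name followed by a NUL byte."""
    if len(function_names) > 255:
        raise ValueError("a one-byte count supports at most 255 functions")
    blob = bytes([len(function_names)])
    for name in function_names:
        blob += name.encode("ascii") + b"\0"
    return blob


def c_escape(data: bytes) -> str:
    """Escape raw bytes for a C string literal using 3-digit octal escapes."""
    out = []
    for b in data:
        ch = chr(b)
        if ch in ('"', "\\"):
            out.append("\\" + ch)
        elif 0x20 <= b < 0x7F:
            out.append(ch)
        else:
            out.append("\\%03o" % b)
    return "".join(out)


def generate_func_registry_c(function_names: List[str]) -> str:
    """Emit C source for a func registry plus a system-lib entry point.

    Declarations of the functions themselves are not emitted here.
    """
    lines = ["#include <tvm/runtime/crt/module.h>"]
    lines.append("static TVMBackendPackedCFunc _tvm_func_array[] = {")
    for name in function_names:
        lines.append(f"    (TVMBackendPackedCFunc){name},")
    lines.append("};")
    names = c_escape(func_registry_names(function_names))
    lines.append("static const TVMFuncRegistry _tvm_func_registry = {")
    lines.append(f'    "{names}",')
    lines.append("    _tvm_func_array,")
    lines.append("};")
    lines.append("static const TVMModule _tvm_system_lib = {")
    lines.append("    &_tvm_func_registry,")
    lines.append("};")
    lines.append("const TVMModule* TVMSystemLibEntryPoint(void) {")
    lines.append("    return &_tvm_system_lib;")
    lines.append("}")
    return "\n".join(lines) + "\n"


print(generate_func_registry_c(["add", "mul"]))
```

Octal escapes are used because a C octal escape stops after three digits; a hex escape such as `\x02` would also swallow the following `add`, since `a` and `d` are valid hex digits.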
    
    * Created CSourceMetaData module for model metadata
    
    * fixed llvm_module to return null pfs for
      get_symbol and get_const_vars
    
    Change-Id: I84810e0695d4d6fb314af2469117f965eed71b51
    
    * Created CSourceMetaData module for model metadata
    
    * fixed bundle_deploy tests
    
    Change-Id: I0d1332a4abbb6830531784c59264021bbbd7148a
    
    * Created CSourceMetaData module for model metadata
    
    * fixed export_library to not insert "options" when targeting tar
    * fixed unit tests
    
    Change-Id: Ia1686889498b71af66f1a0311a059154ad3c2c3e
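
A minimal Python sketch of the "options" guard described above, assuming the rule shown in the export_library hunk of this diff: include paths are appended only when a C module is present and the output is not a .tar archive. `compile_options` is a hypothetical helper, and wrapping a single-string "options" value in a list is a simplification, since that part of export_library is not shown in the diff.

```python
from typing import Dict, List


def compile_options(file_name: str, has_c_module: bool,
                    kwargs: Dict, include_paths: List[str]) -> Dict:
    """Return the kwargs to pass to fcompile (hypothetical helper).

    Include paths only make sense when a C compiler builds C sources,
    so they are skipped when the output is a .tar archive.
    """
    kwargs = dict(kwargs)  # do not mutate the caller's dict
    if has_c_module and not file_name.endswith(".tar"):
        opts = kwargs.get("options", [])
        # Simplification: accept either a single string or a list of flags.
        options = [opts] if isinstance(opts, str) else list(opts)
        kwargs["options"] = options + ["-I" + path for path in include_paths]
    return kwargs


# A tar target gets no -I flags, a shared library does.
print(compile_options("model.tar", True, {}, ["/opt/tvm/include"]))  # {}
print(compile_options("model.so", True, {"options": ["-O2"]}, ["/opt/tvm/include"]))
```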
    
    * Created CSourceMetaData module for model metadata
    
    * enable wasm to support csource metadata module
    * disabled non-DSOExportable modules from using the csource metadata module
    
    Change-Id: Ie09beaad35cbc2ef738d1d24d91e249b5e099569
    
    * Created CSourceMetaData module for model metadata
    
    * changed const pfs to be called only on external modules
      or DSOExportable modules
    
    Change-Id: I6ad28f166c0fc27a2548c851bf9287ec805550d1
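
The resulting split in CreateMetadataModule can be sketched in Python as follows. This is a hypothetical model of the C++ logic in source_module.cc from this diff, with a stand-in `FakeModule` class and tuples in place of real runtime modules: external modules that are DSO-exportable, carry no constants, and are built for a "c" target go into the C-source metadata module together with the host module; everything else is imported into the binary MetadataModule.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeModule:
    """Hypothetical stand-in for a runtime module; not a TVM class."""
    type_key: str
    const_vars: List[str] = field(default_factory=list)


def dso_exportable(mod: FakeModule) -> bool:
    return mod.type_key in ("llvm", "c")


def create_metadata_module(target_module, ext_modules: List[FakeModule],
                           target_kind: Optional[str]):
    """Model of the wrapping decision; tuples stand in for wrapper modules."""
    csource, binary = [], []
    for mod in ext_modules:
        # Only modules with no constants to initialize can skip the binary
        # MetadataModule and be linked through the C-source metadata module.
        if not mod.const_vars and dso_exportable(mod) and target_kind == "c":
            csource.append(mod)
        else:
            binary.append(mod)

    if target_kind == "c":
        # The host (TIR-lowered) module goes into the C-source wrapper as well.
        target_module = ("csource_metadata", csource + [target_module])

    if binary:
        # The binary wrapper imports the (possibly wrapped) host module first.
        return ("binary_metadata", [target_module] + binary)
    return target_module
```

When there are no binary modules, no MetadataModule is created at all, which matches the C++ code returning `target_module` directly.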
    
    * Created CSourceMetaData module for model metadata
    
    * CSourceMetadata module wrapper is only created for c/llvm targets
    
    Change-Id: I13cb4140c17e2e1f91d495b15a1ff7eeab9fb14d
    
    * Created CSourceMetaData module for model metadata
    
    * target should be defined to use the csourcemetadata module
    
    Change-Id: Id8e55b23d0007a79c550334de2c0fec63d40171f
    
    * Created CSourceMetaData module for model metadata
    
    * reinstate llvm func registry
    
    Change-Id: I53e0754b6fb533637f08b25e98064d8c04092de4
    
    * Created CSourceMetaData module for model metadata
    
    * addressed comments and fixed bugs
    
    Change-Id: I26401685dc803aeaf7642c865df88d683419e859
    
    * Created CSourceMetaData module for model metadata
    
    * addressed a missed comment
    
    Change-Id: I65e65c30bc780a946f3f1b8372c40a49a5c20582
    
    * Created CSourceMetaData module for model metadata
    
    * te build interface should only include c-source metadata if
      targeting "c"
    
    Change-Id: Ie23cb8c6231c1f2de6d2827084774e3510288098
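
A sketch of the check added to the te build interface (python/tvm/driver/build_module.py), rewritten in plain Python. `parse_target` is a deliberately simplified, hypothetical stand-in for tvm.target.Target that only understands strings like `c --system-lib` or `c -runtime=c`; the condition itself mirrors the one in this diff: the result is wrapped in the C-source metadata module only when the host target kind is "c" and system-lib is set.

```python
import shlex
from typing import Dict, Tuple


def parse_target(target: str) -> Tuple[str, Dict[str, object]]:
    """Hypothetical, much-simplified parser for 'kind -flag --key=value'."""
    kind, *flags = shlex.split(target)
    attrs: Dict[str, object] = {}
    for flag in flags:
        key, sep, value = flag.lstrip("-").partition("=")
        # A bare flag such as --system-lib is treated as the integer 1.
        attrs[key] = value if sep else 1
    return kind, attrs


def wraps_in_csource_metadata(target: str) -> bool:
    """Mirror of the condition in build(): system-lib set and kind == 'c'."""
    kind, attrs = parse_target(target)
    return attrs.get("system-lib") == 1 and kind == "c"
```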
    
    * Created CSourceMetaData module for model metadata
    
    * c_source modules should be created only if they are
      non-DSO exportable
    
    Change-Id: I53f2f8e9caa41f133446f8881b9dc541ebeee8cc
    
    * Created CSourceMetaData module for model metadata
    
    * documentation misalignment in source_module.cc
    
    Change-Id: I83e2c29b1f2980ca65a694304720dc58a5cb7879
    
    * Created CSourceMetaData module for model metadata
    
    * typo: same object file written as a dependency in the Makefile
    
    Change-Id: I8becc4196d286cfb6372768687b3c836799dcb78
    
    * Created CSourceMetaData module for model metadata
    
    * removed unused param from a brief
    
    Change-Id: Ie4db2aca3b7ea147bd8c65ef5d1cc2146f530e76
    
    * Created CSourceMetaData module for model metadata
    
    * made export library use c as the format for c source modules
    
    Change-Id: Ie2fd6204414f0fa43988a8082d18af7a3225e237
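
The per-module object-format choice that export_library makes after this change can be sketched as follows, following the branches in the module.py hunk of this diff: C source modules are saved with the "c" extension even when fcompile declares its own object_format. `object_format_for` and `fake_compile` are hypothetical names for illustration.

```python
from typing import Callable, Optional


def object_format_for(type_key: str, fcompile: Optional[Callable] = None) -> str:
    """Pick the file extension used when saving one DSO module (sketch)."""
    if fcompile is not None and hasattr(fcompile, "object_format"):
        # C source modules are saved as .c regardless of fcompile's format.
        return "c" if type_key == "c" else fcompile.object_format
    if type_key == "llvm":
        return "o"
    assert type_key == "c", f"not a DSO-exportable module: {type_key!r}"
    return "c"


def fake_compile(output, files, **kwargs):
    """Hypothetical fcompile that just returns its output path."""
    return output


fake_compile.object_format = "o"

# Intermediate file names laid out the way export_library names them.
names = [f"lib{i}.{object_format_for(k, fake_compile)}" for i, k in enumerate(["llvm", "c"])]
print(names)
```

Because export_library now returns whatever fcompile returns, an fcompile like `fake_compile` can hand its artifact back to the caller, which is how python/tvm/micro/build.py collects the library in this diff.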
    
    * Created CSourceMetaData module for model metadata
    
    * addressed a nit
    
    Change-Id: I6084b8c06ddfaaece295439dbab589e6e202b664
---
 apps/bundle_deploy/build_model.py                  |   2 -
 python/tvm/driver/build_module.py                  |  12 ++
 python/tvm/micro/build.py                          |  14 +-
 python/tvm/runtime/module.py                       |  40 +++--
 src/relay/backend/build_module.cc                  |  10 +-
 src/relay/backend/contrib/codegen_c/codegen.cc     |  26 ++-
 src/relay/backend/contrib/dnnl/codegen.cc          |   3 +-
 src/relay/backend/vm/compiler.cc                   |   6 +-
 src/target/func_registry_generator.cc              |   2 +-
 src/target/func_registry_generator.h               |   7 +-
 src/target/llvm/codegen_cpu.cc                     |   4 +-
 src/target/llvm/llvm_module.cc                     |  14 +-
 src/target/source/codegen_c_host.cc                |  31 +---
 src/target/source/codegen_c_host.h                 |   8 +-
 src/target/source/codegen_source_base.h            |  22 ++-
 src/target/source/source_module.cc                 | 183 ++++++++++++++++++---
 tests/micro/qemu/test_zephyr.py                    |  30 +++-
 tests/python/relay/test_pass_partition_graph.py    |  28 ++--
 tests/python/unittest/test_crt.py                  |   1 -
 tests/python/unittest/test_link_params.py          |   2 +-
 .../python/unittest/test_runtime_module_export.py  |   2 +-
 21 files changed, 316 insertions(+), 131 deletions(-)

diff --git a/apps/bundle_deploy/build_model.py b/apps/bundle_deploy/build_model.py
index 623d246..a2513c8 100644
--- a/apps/bundle_deploy/build_model.py
+++ b/apps/bundle_deploy/build_model.py
@@ -51,7 +51,6 @@ def build_module(opts):
         build_dir = os.path.abspath(opts.out_dir)
         if not os.path.isdir(build_dir):
             os.makedirs(build_dir)
-
         lib.save(os.path.join(build_dir, file_format_str.format(name="model", ext="o")))
         with open(
             os.path.join(build_dir, file_format_str.format(name="graph", ext="json")), "w"
@@ -85,7 +84,6 @@ def build_test_module(opts):
         build_dir = os.path.abspath(opts.out_dir)
         if not os.path.isdir(build_dir):
             os.makedirs(build_dir)
-
         lib.save(os.path.join(build_dir, file_format_str.format(name="test_model", ext="o")))
         with open(
             os.path.join(build_dir, file_format_str.format(name="test_graph", ext="json")), "w"
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index dc9d741..7ad48e1 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -424,4 +424,16 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
     for mdev in device_modules:
         if mdev:
             rt_mod_host.import_module(mdev)
+
+    if not isinstance(target_host, Target):
+        target_host = Target(target_host)
+    if (
+        "system-lib" in target_host.attrs
+        and target_host.attrs["system-lib"].value == 1
+        and target_host.kind.name == "c"
+    ):
+        create_csource_metadata_module = tvm._ffi.get_global_func(
+            "runtime.CreateCSourceMetadataModule"
+        )
+        return create_csource_metadata_module([rt_mod_host], target_host)
     return rt_mod_host
diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py
index 4aec9ea..cad385b 100644
--- a/python/tvm/micro/build.py
+++ b/python/tvm/micro/build.py
@@ -95,6 +95,7 @@ _CRT_GENERATED_LIB_OPTIONS = copy.copy(_CRT_DEFAULT_OPTIONS)
 #   void* arg0 = (((TVMValue*)args)[0].v_handle);
 #   int32_t arg0_code = ((int32_t*)arg_type_ids)[(0)];
 _CRT_GENERATED_LIB_OPTIONS["cflags"].append("-Wno-unused-variable")
+_CRT_GENERATED_LIB_OPTIONS["ccflags"].append("-Wno-unused-variable")
 
 
 # Many TVM-intrinsic operators (i.e. expf, in particular)
@@ -159,9 +160,6 @@ def build_static_runtime(
     mod_build_dir = workspace.relpath(os.path.join("build", "module"))
     os.makedirs(mod_build_dir)
     mod_src_dir = workspace.relpath(os.path.join("src", "module"))
-    os.makedirs(mod_src_dir)
-    mod_src_path = os.path.join(mod_src_dir, "module.c")
-    module.save(mod_src_path, "cc")
 
     libs = []
     for mod_or_src_dir in (extra_libs or []) + RUNTIME_LIB_SRC_DIRS:
@@ -181,7 +179,15 @@ def build_static_runtime(
 
         libs.append(compiler.library(lib_build_dir, lib_srcs, lib_opts))
 
-    libs.append(compiler.library(mod_build_dir, [mod_src_path], generated_lib_opts))
+    mod_src_dir = workspace.relpath(os.path.join("src", "module"))
+    os.makedirs(mod_src_dir)
+    libs.append(
+        module.export_library(
+            mod_build_dir,
+            workspace_dir=mod_src_dir,
+            fcompile=lambda bdir, srcs, **kwargs: compiler.library(bdir, srcs, generated_lib_opts),
+        )
+    )
 
     runtime_build_dir = workspace.relpath(f"build/runtime")
     os.makedirs(runtime_build_dir)
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index cef6173..6326796 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint: disable=invalid-name, unused-import, import-outside-toplevel
+# pylint: disable=invalid-name, unused-import, import-outside-toplevel, inconsistent-return-statements
 """Runtime Module namespace."""
 import os
 import ctypes
@@ -252,7 +252,7 @@ class Module(object):
     def _dso_exportable(self):
         return self.type_key == "llvm" or self.type_key == "c"
 
-    def export_library(self, file_name, fcompile=None, addons=None, **kwargs):
+    def export_library(self, file_name, fcompile=None, addons=None, workspace_dir=None, **kwargs):
         """Export the module and its imported device code one library.
 
         This function only works on host llvm modules.
@@ -268,8 +268,19 @@ class Module(object):
             If fcompile has attribute object_format, will compile host library
             to that format. Otherwise, will use default format "o".
 
+        workspace_dir : str, optional
+            the path to a directory used to create intermediary
+            artifacts for the process exporting of the library.
+            If this is not provided a temporary dir will be created.
+
         kwargs : dict, optional
             Additional arguments passed to fcompile
+
+        Returns
+        -------
+        result of fcompile()  : unknown, optional
+            If the compilation function returns an artifact it would be returned via
+            export_library, if any.
         """
         # NOTE: this function depends on contrib library features
         # which are only available in when TVM function is available.
@@ -292,22 +303,28 @@ class Module(object):
             return
 
         modules = self._collect_dso_modules()
-        temp = _utils.tempdir()
+        if workspace_dir is None:
+            temp = _utils.tempdir()
+            workspace_dir = temp.temp_dir
         files = addons if addons else []
         is_system_lib = False
         has_c_module = False
         llvm_target_triple = None
         for index, module in enumerate(modules):
             if fcompile is not None and hasattr(fcompile, "object_format"):
-                object_format = fcompile.object_format
+                if module.type_key == "c":
+                    object_format = "c"
+                    has_c_module = True
+                else:
+                    object_format = fcompile.object_format
             else:
                 if module.type_key == "llvm":
                     object_format = "o"
                 else:
                     assert module.type_key == "c"
-                    object_format = "cc"
+                    object_format = "c"
                     has_c_module = True
-            path_obj = temp.relpath("lib" + str(index) + "." + object_format)
+            path_obj = os.path.join(workspace_dir, f"lib{index}.{object_format}")
             module.save(path_obj)
             files.append(path_obj)
             is_system_lib = (
@@ -330,17 +347,20 @@ class Module(object):
 
         if self.imported_modules:
             if enabled("llvm") and llvm_target_triple:
-                path_obj = temp.relpath("devc." + object_format)
+                path_obj = os.path.join(workspace_dir, f"devc.{object_format}")
                 m = _ffi_api.ModulePackImportsToLLVM(self, is_system_lib, llvm_target_triple)
                 m.save(path_obj)
                 files.append(path_obj)
             else:
-                path_cc = temp.relpath("devc.cc")
+                path_cc = os.path.join(workspace_dir, "devc.c")
                 with open(path_cc, "w") as f:
                     f.write(_ffi_api.ModulePackImportsToC(self, is_system_lib))
                 files.append(path_cc)
 
-        if has_c_module:
+        # The imports could contain a c module but the object format could be tar
+        # Thus, it would not recognize the following include paths as options
+        # which are there assuming a c compiler is the fcompile.
+        if has_c_module and not file_name.endswith(".tar"):
             options = []
             if "options" in kwargs:
                 opts = kwargs["options"]
@@ -348,7 +368,7 @@ class Module(object):
             opts = options + ["-I" + path for path in find_include_path()]
             kwargs.update({"options": opts})
 
-        fcompile(file_name, files, **kwargs)
+        return fcompile(file_name, files, **kwargs)
 
 
 def system_lib():
diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc
index a0828d1..09b0966 100644
--- a/src/relay/backend/build_module.cc
+++ b/src/relay/backend/build_module.cc
@@ -510,18 +510,14 @@ class RelayBuildModule : public runtime::ModuleNode {
         // If we cannot decide the target is LLVM, we create an empty CSourceModule.
         // The code content is initialized with ";" to prevent complaining
         // from CSourceModuleNode::SaveToFile.
-        ret_.mod = tvm::codegen::CSourceModuleCreate(";", "");
+        ret_.mod = tvm::codegen::CSourceModuleCreate(";", "", Array<String>{});
       }
     } else {
       ret_.mod = tvm::build(lowered_funcs, target_host_);
     }
 
-    Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
-    // TODO(zhiics) We should be able to completely switch to MetadataModule no
-    // matter whether there are external modules or not.
-    if (!ext_mods.empty()) {
-      ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods);
-    }
+    auto ext_mods = graph_codegen_->GetExternalModules();
+    ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, GetTargetHost());
   }
 
  private:
diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc
index 935ac16..998393d 100644
--- a/src/relay/backend/contrib/codegen_c/codegen.cc
+++ b/src/relay/backend/contrib/codegen_c/codegen.cc
@@ -215,20 +215,19 @@ class CodegenC : public MemoizedExprTranslator<std::vector<Output>>, public Code
 
 class CSourceCodegen : public CSourceModuleCodegenBase {
  public:
-  std::pair<std::string, Array<String>> GenCFunc(const Function& func) {
+  std::tuple<Array<String>, String, String> GenCFunc(const Function& func) {
     ICHECK(func.defined()) << "Input error: expect a Relay function.";
-
-    // Record the external symbol for runtime lookup.
-    auto sid = GetExtSymbol(func);
-
-    CodegenC builder(sid);
+    CodegenC builder(GetExtSymbol(func));
     auto out = builder.VisitExpr(func->body);
-    code_stream_ << builder.JIT(out);
-
-    return {sid, builder.const_vars_};
+    return std::make_tuple(builder.const_vars_, builder.ext_func_id_, builder.JIT(out));
   }
 
   runtime::Module CreateCSourceModule(const ObjectRef& ref) override {
+    ICHECK(ref->IsInstance<FunctionNode>());
+    auto res = GenCFunc(Downcast<Function>(ref));
+    Array<String> variables = std::get<0>(res);
+    String func_name = std::get<1>(res);
+
     // Create headers
     code_stream_ << "#include <cstring>\n";
     code_stream_ << "#include <vector>\n";
@@ -259,18 +258,13 @@ class CSourceCodegen : public CSourceModuleCodegenBase {
     )op_macro";
 
     code_stream_ << operator_macro << "\n\n";
-
-    ICHECK(ref->IsInstance<FunctionNode>());
-    auto res = GenCFunc(Downcast<Function>(ref));
+    code_stream_ << std::get<2>(res);
     std::string code = code_stream_.str();
 
-    String sym = std::get<0>(res);
-    Array<String> variables = std::get<1>(res);
-
     // Create a CSource module
     const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
     ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
-    return (*pf)(code, "c", sym, variables);
+    return (*pf)(code, "c", Array<String>{func_name}, variables);
   }
 
  private:
diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc
index bfc5c77..c9a5828 100644
--- a/src/relay/backend/contrib/dnnl/codegen.cc
+++ b/src/relay/backend/contrib/dnnl/codegen.cc
@@ -413,7 +413,8 @@ class DNNLModuleCodegen : public CSourceModuleCodegenBase {
     // Create a CSource module
     const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate");
     ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module";
-    return (*pf)(code, "c", sym, variables);
+    // TODO(@manupa-arm): pass the function names to enable system-lib creation
+    return (*pf)(code, "c", Array<String>{sym}, variables);
   }
 
  private:
diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc
index bed2510..8fbe31e 100644
--- a/src/relay/backend/vm/compiler.cc
+++ b/src/relay/backend/vm/compiler.cc
@@ -1146,11 +1146,9 @@ void VMCompiler::Codegen() {
   } else {
     // There is no function handled by TVM. We create a virtual main module
     // to make sure a DSO module will be also available.
-    exec_->lib = codegen::CSourceModuleCreate(";", "");
-  }
-  if (!ext_mods.empty()) {
-    exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods);
+    exec_->lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
   }
+  exec_->lib = codegen::CreateMetadataModule(params_, exec_->lib, ext_mods, target_host_);
 }
 
 ExprDeviceMap VMCompiler::AnalyzeContext() const {
diff --git a/src/target/func_registry_generator.cc b/src/target/func_registry_generator.cc
index 402d0f8..7c948d5 100644
--- a/src/target/func_registry_generator.cc
+++ b/src/target/func_registry_generator.cc
@@ -29,7 +29,7 @@
 namespace tvm {
 namespace target {
 
-std::string GenerateFuncRegistryNames(const std::vector<std::string>& function_names) {
+std::string GenerateFuncRegistryNames(const Array<String>& function_names) {
   std::stringstream ss;
   ss << (unsigned char)(function_names.size());
   for (auto f : function_names) {
diff --git a/src/target/func_registry_generator.h b/src/target/func_registry_generator.h
index 362fca8..fb59648 100644
--- a/src/target/func_registry_generator.h
+++ b/src/target/func_registry_generator.h
@@ -24,13 +24,18 @@
 #ifndef TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_
 #define TVM_TARGET_FUNC_REGISTRY_GENERATOR_H_
 
+#include <tvm/runtime/container.h>
+
 #include <string>
 #include <vector>
 
+using tvm::runtime::Array;
+using tvm::runtime::String;
+
 namespace tvm {
 namespace target {
 
-std::string GenerateFuncRegistryNames(const std::vector<std::string>& function_names);
+std::string GenerateFuncRegistryNames(const Array<String>& function_names);
 
 }  // namespace target
 }  // namespace tvm
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index fea5f80..6143e70 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -794,10 +794,10 @@ llvm::Value* CodeGenCPU::RuntimeTVMParallelBarrier() {
 void CodeGenCPU::AddStartupFunction() {
   if (registry_functions_.size() != 0) {
     ICHECK(is_system_lib_) << "Loading of --system-lib modules is yet to be defined for C runtime";
-    std::vector<std::string> symbols;
+    Array<String> symbols;
     std::vector<llvm::Constant*> funcs;
     for (auto sym : registry_functions_) {
-      symbols.emplace_back(sym.first);
+      symbols.push_back(sym.first);
       funcs.emplace_back(llvm::ConstantExpr::getBitCast(
           sym.second, ftype_tvm_backend_packed_c_func_->getPointerTo()));
     }
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index a0ab49d..43d2097 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -60,6 +60,13 @@ class LLVMModuleNode final : public runtime::ModuleNode {
     if (name == "__tvm_is_system_module") {
       bool flag = (mptr_->getFunction("__tvm_module_startup") != nullptr);
       return PackedFunc([flag](TVMArgs args, TVMRetValue* rv) { *rv = flag; });
+    } else if (name == "get_func_names") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->function_names_; });
+    } else if (name == "get_symbol") {
+      return PackedFunc(nullptr);
+    } else if (name == "get_const_vars") {
+      return PackedFunc(nullptr);
     } else if (name == "_get_target_triple") {
       std::string target_triple = tm_->getTargetTriple().str();
       // getTargetTriple() doesn't include other flags besides the triple. Add back flags which are
@@ -218,9 +225,10 @@ class LLVMModuleNode final : public runtime::ModuleNode {
       ICHECK(kv.second->IsInstance<PrimFuncNode>())
           << "Can only lower IR Module with PrimFuncs, but got " << kv.second->GetTypeKey();
       auto f = Downcast<PrimFunc>(kv.second);
+      auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+      ICHECK(global_symbol.defined());
+      function_names_.push_back(global_symbol.value());
       if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-        auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-        ICHECK(global_symbol.defined());
         entry_func = global_symbol.value();
       }
       funcs.push_back(f);
@@ -377,6 +385,8 @@ class LLVMModuleNode final : public runtime::ModuleNode {
   std::unique_ptr<llvm::Module> module_;
   // the context.
   std::shared_ptr<llvm::LLVMContext> ctx_;
+  /* \brief names of the functions declared in this module */
+  Array<String> function_names_;
 };
 
 TVM_REGISTER_GLOBAL("target.build.llvm")
diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc
index 0a19fc1..bee5441 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -55,7 +55,7 @@ void CodeGenCHost::AddFunction(const PrimFunc& f) {
   auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
   ICHECK(global_symbol.defined())
       << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
-  function_names_.emplace_back(global_symbol.value());
+  function_names_.push_back(global_symbol.value());
 
   CodeGenC::AddFunction(f);
 }
@@ -73,7 +73,7 @@ void CodeGenCHost::LinkParameters(Map<String, LinkedParam> params) {
          << "        out_ret_tcode[0] = " << kTVMNullptr << ";\n"
          << "        return 0;\n";
 
-  function_names_.emplace_back(tvm::runtime::symbol::tvm_lookup_linked_param);
+  function_names_.push_back(tvm::runtime::symbol::tvm_lookup_linked_param);
   for (auto kv : params) {
     decl_stream << "\n"
                 << "#ifdef __cplusplus\n"
@@ -322,29 +322,6 @@ inline void CodeGenCHost::PrintTernaryCondExpr(const T* op, const char* compare,
      << "? (" << a_id << ") : (" << b_id << "))";
 }
 
-void CodeGenCHost::GenerateFuncRegistry() {
-  decl_stream << "#include <tvm/runtime/crt/module.h>\n";
-  stream << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n";
-  for (auto f : function_names_) {
-    stream << "    (TVMBackendPackedCFunc)" << f << ",\n";
-  }
-  stream << "};\n";
-  auto registry = target::GenerateFuncRegistryNames(function_names_);
-  stream << "static const TVMFuncRegistry _tvm_func_registry = {\n"
-         << "    \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\","
-         << "    _tvm_func_array,\n"
-         << "};\n";
-}
-
-void CodeGenCHost::GenerateCrtSystemLib() {
-  stream << "static const TVMModule _tvm_system_lib = {\n"
-         << "    &_tvm_func_registry,\n"
-         << "};\n"
-         << "const TVMModule* TVMSystemLibEntryPoint(void) {\n"
-         << "    return &_tvm_system_lib;\n"
-         << "}\n";
-}
-
 runtime::Module BuildCHost(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
@@ -380,12 +357,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   if (target->GetAttr<Bool>("system-lib").value_or(Bool(false))) {
     ICHECK_EQ(target->GetAttr<String>("runtime").value_or(""), "c")
         << "c target only supports generating C runtime SystemLibs";
-    cg.GenerateFuncRegistry();
-    cg.GenerateCrtSystemLib();
   }
 
   std::string code = cg.Finish();
-  return CSourceModuleCreate(code, "c");
+  return CSourceModuleCreate(code, "c", cg.GetFunctionNames());
 }
 
 TVM_REGISTER_GLOBAL("target.build.c").set_body_typed(BuildCHost);
diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h
index b54b6fb..97fe7ab 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -59,18 +59,14 @@ class CodeGenCHost final : public CodeGenC {
 
   void VisitStmt_(const AssertStmtNode* op) final;  // NOLINT(*)
 
-  /*! \brief Generate C runtime FuncRegistry global constant. */
-  void GenerateFuncRegistry();
-
-  /*! \brief Generate C runtime SystemLib entry point. */
-  void GenerateCrtSystemLib();
+  Array<String> GetFunctionNames() { return function_names_; }
 
  private:
   std::string module_name_;
   /* \brief tracks declared global variables which live despite GetUniqueName */
   std::set<std::string> declared_globals_;
   /* \brief names of the functions declared in this module */
-  std::vector<std::string> function_names_;
+  Array<String> function_names_;
   /*! \brief whether to emit asserts in the resulting C code */
   bool emit_asserts_;
 
diff --git a/src/target/source/codegen_source_base.h b/src/target/source/codegen_source_base.h
index 7e5e403..ed838f8 100644
--- a/src/target/source/codegen_source_base.h
+++ b/src/target/source/codegen_source_base.h
@@ -136,25 +136,26 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt);
  * \brief Create a C source module for viewing and compiling GCC code.
  * \param code The code to be viewed.
  * \param fmt The code format.
- * \param symbol The symbol that the c source module represents.
+ * \param func_names The name of functions inside the runtime module.
  * \param const_vars. The constant variables that the c source module needs.
  * \return The created module.
  */
 runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
-                                    const String& symbol = "",
+                                    const Array<String>& func_names,
                                     const Array<String>& const_vars = {});
 
 /*!
  * \brief Wrap the submodules in a metadata module.
  * \param params The variable to constant mapping that is collected by the host
  *        module.
- * \param dso_module The host module to be wrapped.
- * \param modules The modules to be wrapped.
+ * \param target_module The main TIR-lowered internal runtime module
+ * \param modules All the external modules that needs to be imported inside the metadata module(s).
+ * \param target The target that all the modules are compiled for
  * \return The wrapped module.
  */
 runtime::Module CreateMetadataModule(
-    const std::unordered_map<std::string, runtime::NDArray>& params,
-    const runtime::Module& dso_module, const Array<runtime::Module>& modules);
+    const std::unordered_map<std::string, runtime::NDArray>& params, runtime::Module target_module,
+    const Array<runtime::Module>& ext_modules, Target target);
 
 /*!
  * \brief Create a source module for viewing and limited saving for device.
@@ -167,6 +168,15 @@ runtime::Module CreateMetadataModule(
 runtime::Module DeviceSourceModuleCreate(
     std::string data, std::string fmt, std::unordered_map<std::string, runtime::FunctionInfo> fmap,
     std::string type_key, std::function<std::string(const std::string&)> fget_source = nullptr);
+
+/*!
+ * \brief Wrap the submodules that are to be wrapped in a c-source metadata module.
+ * \param modules The modules to be wrapped.
+ * \param target the target the modules are compiled for.
+ * \return The wrapped module.
+ */
+runtime::Module CreateCSourceMetadataModule(const Array<runtime::Module>& modules, Target target);
+
 }  // namespace codegen
 }  // namespace tvm
 #endif  // TVM_TARGET_SOURCE_CODEGEN_SOURCE_BASE_H_
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index 3be658a..4b4770a 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -27,6 +27,8 @@
 
 #include "../../runtime/file_utils.h"
 #include "../../runtime/meta_data.h"
+#include "../../support/str_escape.h"
+#include "../func_registry_generator.h"
 #include "codegen_source_base.h"
 
 namespace tvm {
@@ -46,40 +48,66 @@ using runtime::SaveBinaryToFile;
  *        codegens, such as graph runtime codegen and the vm compiler.
  *
  * \param params The metadata for initialization of all modules.
- * \param dso_module The DSO module that contains TVM primitives.
- * \param modules The submodules that will be wrapped, e.g. CSource modules that
- *        contain vendor library calls or customized runtime modules.
- *
+ * \param target_module the internal module that is compiled by tvm.
+ * \param ext_modules The external modules that needs to be imported inside the metadata
+ * module(s).
+ * \param target The target that all the modules are compiled for
  * \return The created metadata module that manages initialization of metadata.
  */
 runtime::Module CreateMetadataModule(
     const std::unordered_map<std::string, runtime::NDArray>& params,
-    const runtime::Module& dso_module, const Array<runtime::Module>& modules) {
+    tvm::runtime::Module target_module, const Array<runtime::Module>& ext_modules, Target target) {
+  Array<tvm::runtime::Module> csource_modules;
+  Array<tvm::runtime::Module> binary_modules;
+
+  auto DSOExportable = [](tvm::runtime::Module& mod) {
+    return !std::strcmp(mod->type_key(), "llvm") || !std::strcmp(mod->type_key(), "c");
+  };
+
   // Wrap all submodules in the initialization wrapper.
   std::unordered_map<std::string, std::vector<std::string>> sym_metadata;
-  for (runtime::Module it : modules) {
-    auto pf_sym = it.GetFunction("get_symbol");
-    auto pf_var = it.GetFunction("get_const_vars");
+  for (tvm::runtime::Module mod : ext_modules) {
+    auto pf_sym = mod.GetFunction("get_symbol");
+    auto pf_var = mod.GetFunction("get_const_vars");
+    std::vector<std::string> arrays;
     if (pf_sym != nullptr && pf_var != nullptr) {
       String symbol = pf_sym();
       Array<String> variables = pf_var();
-      std::vector<std::string> arrays;
       for (size_t i = 0; i < variables.size(); i++) {
         arrays.push_back(variables[i].operator std::string());
       }
       ICHECK_EQ(sym_metadata.count(symbol), 0U) << "Found duplicated symbol: " << symbol;
       sym_metadata[symbol] = arrays;
     }
+    // We only need loading of serialized constant data
+    // if there are constants present and required by the
+    // runtime module to be initialized by the binary
+    // metadata module. If not, the rest of the modules are
+    // wrapped in the c-source metadata module.
+
+    // TODO(@manupa-arm) : we should be able to use csource_metadata
+    // if the variables are empty when all the runtime modules implement get_func_names
+    if (arrays.empty() && DSOExportable(mod) && target->kind->name == "c") {
+      csource_modules.push_back(mod);
+    } else {
+      binary_modules.push_back(mod);
+    }
   }
 
-  // Wrap the modules.
-  runtime::Module init_m = runtime::MetadataModuleCreate(params, sym_metadata);
-  init_m.Import(dso_module);
-  for (const auto& it : modules) {
-    init_m.Import(it);
+  if (target.defined() && target->kind->name == "c") {
+    csource_modules.push_back(target_module);
+    target_module = CreateCSourceMetadataModule(csource_modules, target);
   }
 
-  return init_m;
+  if (!binary_modules.empty()) {
+    runtime::Module binary_meta_mod = runtime::MetadataModuleCreate(params, sym_metadata);
+    binary_meta_mod.Import(target_module);
+    for (const auto& it : binary_modules) {
+      binary_meta_mod.Import(it);
+    }
+    return binary_meta_mod;
+  }
+  return target_module;
 }
 
 // Simulator function
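The per-module decision that `CreateMetadataModule` makes in the hunk above can be sketched in plain Python to show which wrapper ends up on top. `FakeModule`, `partition_ext_modules` and `describe_result` are hypothetical stand-ins, not TVM APIs. A module is modelled as "having constants" when its `get_const_vars` list is non-empty, and `target_kind=None` models an undefined `Target`, which the sketch simply treats as not "c".

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class FakeModule:
    """Hypothetical stand-in for a runtime module (only the inspected fields)."""
    name: str
    type_key: str  # e.g. "llvm", "c", "json"
    const_vars: List[str] = field(default_factory=list)


# Mirrors the DSOExportable lambda: only "llvm" and "c" modules qualify.
DSO_EXPORTABLE = {"llvm", "c"}


def partition_ext_modules(ext_modules, target_kind: Optional[str]):
    """Split external modules into c-source and binary-metadata groups."""
    csource, binary = [], []
    for mod in ext_modules:
        exportable = mod.type_key in DSO_EXPORTABLE
        if not mod.const_vars and exportable and target_kind == "c":
            csource.append(mod)
        else:
            binary.append(mod)
    return csource, binary


def describe_result(target_module, ext_modules, target_kind):
    """Return a nested description of the module that would be returned."""
    csource, binary = partition_ext_modules(ext_modules, target_kind)
    top = target_module.name
    if target_kind == "c":
        # The target module is appended last, after the eligible externals,
        # and all of them are imported into one c-source metadata module.
        top = ("CSourceMetadata", [m.name for m in csource + [target_module]])
    if binary:
        # Modules needing serialized constants still get the binary wrapper,
        # which imports the (possibly wrapped) target module first.
        return ("BinaryMetadata", top, [m.name for m in binary])
    return top
```

For example, a target "c" build with one constant-free "c" external module and one external module carrying constants yields a binary metadata module wrapping a c-source metadata module. Note that the committed C++ loop reads `target->kind->name` before the later `target.defined()` check, so the sketch's handling of an undefined target is an assumption, not a mirror of that code.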
@@ -109,18 +137,25 @@ runtime::Module SourceModuleCreate(std::string code, std::string fmt) {
 // Simulator function
 class CSourceModuleNode : public runtime::ModuleNode {
  public:
-  CSourceModuleNode(const std::string& code, const std::string& fmt, const std::string& symbol,
-                    const Array<String>& const_vars)
-      : code_(code), fmt_(fmt), symbol_(symbol), const_vars_(const_vars) {}
+  CSourceModuleNode(const std::string& code, const std::string& fmt,
+                    const Array<String>& func_names, const Array<String>& const_vars)
+      : code_(code), fmt_(fmt), const_vars_(const_vars), func_names_(func_names) {}
   const char* type_key() const { return "c"; }
 
   PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
+    // Currently the c-source module is used for demonstration purposes with the binary metadata
+    // module, which expects the get_symbol interface. When the c-source module is used as an
+    // external module, it will only contain one function. However, when it's used as an internal
+    // module (e.g., target "c"), it can have many functions.
     if (name == "get_symbol") {
       return PackedFunc(
-          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_; });
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_[0]; });
     } else if (name == "get_const_vars") {
       return PackedFunc(
           [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; });
+    } else if (name == "get_func_names") {
+      return PackedFunc(
+          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->func_names_; });
     } else {
       return PackedFunc(nullptr);
     }
@@ -131,7 +166,7 @@ class CSourceModuleNode : public runtime::ModuleNode {
   void SaveToFile(const std::string& file_name, const std::string& format) final {
     std::string fmt = GetFileFormat(file_name, format);
     std::string meta_file = GetMetaFilePath(file_name);
-    if (fmt == "cc") {
+    if (fmt == "c") {
       ICHECK_NE(code_.length(), 0);
       SaveBinaryToFile(file_name, code_);
     } else {
@@ -142,17 +177,109 @@ class CSourceModuleNode : public runtime::ModuleNode {
  protected:
   std::string code_;
   std::string fmt_;
-  std::string symbol_;
   Array<String> const_vars_;
+  Array<String> func_names_;
 };
 
-runtime::Module CSourceModuleCreate(const String& code, const String& fmt, const String& symbol,
+runtime::Module CSourceModuleCreate(const String& code, const String& fmt,
+                                    const Array<String>& func_names,
                                     const Array<String>& const_vars) {
   auto n = make_object<CSourceModuleNode>(code.operator std::string(), fmt.operator std::string(),
-                                          symbol.operator std::string(), const_vars);
+                                          func_names, const_vars);
   return runtime::Module(n);
 }
 
+class CSourceMetadataModuleNode : public runtime::ModuleNode {
+ public:
+  CSourceMetadataModuleNode(const Array<String>& func_names, const std::string& fmt, Target target)
+      : fmt_(fmt), func_names_(func_names), target_(target) {
+    CreateSource();
+  }
+  const char* type_key() const { return "c"; }
+
+  std::string GetSource(const std::string& format) final { return code_.str(); }
+
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
+    return PackedFunc(nullptr);
+  }
+
+  void SaveToFile(const std::string& file_name, const std::string& format) final {
+    std::string fmt = GetFileFormat(file_name, format);
+    std::string meta_file = GetMetaFilePath(file_name);
+    if (fmt == "c") {
+      auto code_str = code_.str();
+      ICHECK_NE(code_str.length(), 0);
+      SaveBinaryToFile(file_name, code_str);
+    } else {
+      ICHECK_EQ(fmt, fmt_) << "Can only save to format=" << fmt_;
+    }
+  }
+
+ protected:
+  std::stringstream code_;
+  std::string fmt_;
+  Array<String> func_names_;
+  Target target_;
+
+  void CreateFuncRegistry() {
+    code_ << "#include <tvm/runtime/crt/module.h>\n";
+    for (const auto& fname : func_names_) {
+      code_ << "#ifdef __cplusplus\n";
+      code_ << "extern \"C\"\n";
+      code_ << "#endif\n";
+      code_ << "TVM_DLL int32_t " << fname.data();
+      code_ << "(TVMValue* args, int* type_code, int num_args, TVMValue* out_value, int* "
+               "out_type_code);\n";
+    }
+    code_ << "static TVMBackendPackedCFunc _tvm_func_array[] = {\n";
+    for (auto f : func_names_) {
+      code_ << "    (TVMBackendPackedCFunc)" << f << ",\n";
+    }
+    code_ << "};\n";
+    auto registry = target::GenerateFuncRegistryNames(func_names_);
+    code_ << "static const TVMFuncRegistry _tvm_func_registry = {\n"
+          << "    \"" << ::tvm::support::StrEscape(registry.data(), registry.size(), true) << "\","
+          << "    _tvm_func_array,\n"
+          << "};\n";
+  }
+
+  void GenerateCrtSystemLib() {
+    code_ << "static const TVMModule _tvm_system_lib = {\n"
+          << "    &_tvm_func_registry,\n"
+          << "};\n"
+          << "const TVMModule* TVMSystemLibEntryPoint(void) {\n"
+          << "    return &_tvm_system_lib;\n"
+          << "}\n";
+  }
+
+  void CreateSource() {
+    if (target_->GetAttr<Bool>("system-lib").value_or(Bool(false)) && !func_names_.empty()) {
+      CreateFuncRegistry();
+      GenerateCrtSystemLib();
+    }
+    code_ << ";";
+  }
+};
+
+runtime::Module CreateCSourceMetadataModule(const Array<runtime::Module>& modules, Target target) {
+  Array<String> func_names;
+  for (runtime::Module mod : modules) {
+    auto pf_funcs = mod.GetFunction("get_func_names");
+    if (pf_funcs != nullptr) {
+      Array<String> func_names_ = pf_funcs();
+      for (const auto& fname : func_names_) {
+        func_names.push_back(fname);
+      }
+    }
+  }
+  auto n = make_object<CSourceMetadataModuleNode>(func_names, "cc", target);
+  auto csrc_metadata_module = runtime::Module(n);
+  for (const auto& mod : modules) {
+    csrc_metadata_module.Import(mod);
+  }
+  return std::move(csrc_metadata_module);
+}
+
 // supports limited save without cross compile
 class DeviceSourceModuleNode final : public runtime::ModuleNode {
  public:
@@ -209,8 +336,14 @@ runtime::Module DeviceSourceModuleCreate(
 
 TVM_REGISTER_GLOBAL("runtime.SourceModuleCreate").set_body_typed(SourceModuleCreate);
 
 TVM_REGISTER_GLOBAL("runtime.CSourceModuleCreate")
-    .set_body_typed([](String code, String fmt, String symbol, Array<String> const_vars) {
-      return CSourceModuleCreate(code, fmt, symbol, const_vars);
+    .set_body_typed([](String code, String fmt, Array<String> func_names,
+                       Array<String> const_vars) {
+      return CSourceModuleCreate(code, fmt, func_names, const_vars);
+    });
+
+TVM_REGISTER_GLOBAL("runtime.CreateCSourceMetadataModule")
+    .set_body_typed([](const Array<runtime::Module>& modules, Target target) {
+      return CreateCSourceMetadataModule(modules, target);
     });
 
 }  // namespace codegen
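The C source that `CSourceMetadataModuleNode::CreateSource` streams out when the target sets `system-lib` can be approximated by the Python sketch below, which follows `CreateFuncRegistry` and `GenerateCrtSystemLib` line by line. The function name `generate_crt_system_lib_source` is hypothetical. The encoded names string produced by `target::GenerateFuncRegistryNames` is passed in as an opaque, already-escaped literal rather than reproduced, and line breaks differ slightly from the C++ stream output.

```python
def generate_crt_system_lib_source(func_names, registry_names_literal, system_lib=True):
    """Approximate the C source emitted by CSourceMetadataModuleNode (sketch)."""
    lines = []
    # CreateSource only emits the registry when system-lib is set and at
    # least one function name was collected from the imported modules.
    if system_lib and func_names:
        lines.append("#include <tvm/runtime/crt/module.h>")
        # CreateFuncRegistry: one C-linkage prototype per packed function.
        for fname in func_names:
            lines += [
                "#ifdef __cplusplus",
                'extern "C"',
                "#endif",
                f"TVM_DLL int32_t {fname}(TVMValue* args, int* type_code, int num_args, "
                "TVMValue* out_value, int* out_type_code);",
            ]
        # Table of function pointers, in the same order as func_names.
        lines.append("static TVMBackendPackedCFunc _tvm_func_array[] = {")
        lines += [f"    (TVMBackendPackedCFunc){f}," for f in func_names]
        lines.append("};")
        # Registry pairs the encoded names string with the pointer table.
        lines += [
            "static const TVMFuncRegistry _tvm_func_registry = {",
            f'    "{registry_names_literal}",',
            "    _tvm_func_array,",
            "};",
            # GenerateCrtSystemLib: the entry point the CRT uses for system-lib.
            "static const TVMModule _tvm_system_lib = {",
            "    &_tvm_func_registry,",
            "};",
            "const TVMModule* TVMSystemLibEntryPoint(void) {",
            "    return &_tvm_system_lib;",
            "}",
        ]
    # CreateSource always appends a lone ';', so the file is never empty.
    lines.append(";")
    return "\n".join(lines)
```

When `system-lib` is off or no module implements `get_func_names`, the generated file contains only the trailing `;`, which keeps the "non-empty source" check in `SaveToFile` satisfied.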
diff --git a/tests/micro/qemu/test_zephyr.py b/tests/micro/qemu/test_zephyr.py
index 2213203..3e73307 100644
--- a/tests/micro/qemu/test_zephyr.py
+++ b/tests/micro/qemu/test_zephyr.py
@@ -29,7 +29,7 @@ import numpy as np
 import tvm
 import tvm.rpc
 import tvm.micro
-import tvm.relay
+import tvm.relay as relay
 
 from tvm.micro.contrib import zephyr
 from tvm.contrib import utils
@@ -143,5 +143,33 @@ def test_compile_runtime(platform):
         test_basic_add(sess)
 
 
+def test_relay(platform):
+    """Testing a simple relay graph"""
+    model, zephyr_board = PLATFORMS[platform]
+    shape = (10,)
+    dtype = "int8"
+
+    # Construct Relay program.
+    x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype))
+    xx = relay.multiply(x, x)
+    z = relay.add(xx, relay.const(np.ones(shape=shape, dtype=dtype)))
+    func = relay.Function([x], z)
+
+    target = tvm.target.target.micro(model)
+    with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
+        graph, mod, params = tvm.relay.build(func, target=target)
+
+    with _make_session(model, target, zephyr_board, mod) as session:
+        graph_mod = tvm.micro.create_local_graph_runtime(
+            graph, session.get_system_lib(), session.context
+        )
+        graph_mod.set_input(**params)
+        x_in = np.random.randint(10, size=shape[0], dtype=dtype)
+        graph_mod.run(x=x_in)
+        result = graph_mod.get_output(0).asnumpy()
+        tvm.testing.assert_allclose(graph_mod.get_input(0).asnumpy(), x_in)
+        tvm.testing.assert_allclose(result, x_in * x_in + 1)
+
+
 if __name__ == "__main__":
     sys.exit(pytest.main([os.path.dirname(__file__)] + sys.argv[1:]))
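The expected value in `test_relay` can be checked by hand: `x_in` is drawn from [0, 10), so the largest element of `x_in * x_in + 1` is 9 * 9 + 1 = 82, which is below the int8 maximum of 127, so the int8 graph cannot overflow. The NumPy-only sketch below reproduces that reference computation without TVM. The seeded `default_rng` generator is an assumption made for reproducibility; the test itself uses `np.random.randint`.

```python
import numpy as np

shape = (10,)
dtype = "int8"

# Same input range as test_relay: integers in [0, 10), stored as int8.
rng = np.random.default_rng(0)
x_in = rng.integers(0, 10, size=shape[0]).astype(dtype)

# Reference for relay.add(relay.multiply(x, x), ones), computed in int8.
expected = x_in * x_in + np.ones(shape, dtype=dtype)

# Worst case for the input range: 9 * 9 + 1 = 82 < 127, so no int8 overflow.
worst = np.int8(9) * np.int8(9) + np.int8(1)
```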
diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py
index 059d0b4..d8f674e 100644
--- a/tests/python/relay/test_pass_partition_graph.py
+++ b/tests/python/relay/test_pass_partition_graph.py
@@ -273,19 +273,23 @@ def test_multi_node_compiler():
 
     map_inputs = {"w{}".format(i): w_data[i] for i in range(8)}
     map_inputs["x"] = x_data
-    check_result(
-        mod,
-        map_inputs,
-        (30, 10),
-        np.concatenate(
-            (
-                ((x_data + w_data[0]) - w_data[1]) * w_data[2],
-                ((x_data + w_data[3]) - w_data[4]) * w_data[5],
-                x_data + w_data[6] - w_data[7],
+
+    targets = ["llvm", "c -runtime=c --system-lib"]
+    for tgt in targets:
+        check_result(
+            mod,
+            map_inputs,
+            (30, 10),
+            np.concatenate(
+                (
+                    ((x_data + w_data[0]) - w_data[1]) * w_data[2],
+                    ((x_data + w_data[3]) - w_data[4]) * w_data[5],
+                    x_data + w_data[6] - w_data[7],
+                ),
+                axis=0,
             ),
-            axis=0,
-        ),
-    )
+            target=tgt,
+        )
 
 
 def test_extern_ccompiler_single_op():
diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py
index 07a4cfc..1d84d4e 100644
--- a/tests/python/unittest/test_crt.py
+++ b/tests/python/unittest/test_crt.py
@@ -49,7 +49,6 @@ def _make_sess_from_op(workspace, op_name, sched, arg_bufs):
 def _make_session(workspace, mod):
     compiler = tvm.micro.DefaultCompiler(target=TARGET)
     opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR, "host"))
-
     micro_binary = tvm.micro.build_static_runtime(
         # the x86 compiler *expects* you to give the exact same dictionary for both
         # lib_opts and bin_opts. so the library compiler is mutating lib_opts and
diff --git a/tests/python/unittest/test_link_params.py b/tests/python/unittest/test_link_params.py
index e3bd634..da87a31 100644
--- a/tests/python/unittest/test_link_params.py
+++ b/tests/python/unittest/test_link_params.py
@@ -266,7 +266,7 @@ def test_c_link_params():
             lib = tvm.relay.build(mod, target, params=param_init)
             assert set(lib.params.keys()) == {"p0", "p1"}  # NOTE: op folded
 
-            src = lib.lib.get_source()
+            src = lib.lib.imported_modules[0].get_source()
             lib.lib.save("test.c", "cc")
             c_dtype = _get_c_datatype(dtype)
             src_lines = src.split("\n")
diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py
index 88b7af9..af9a8ab 100644
--- a/tests/python/unittest/test_runtime_module_export.py
+++ b/tests/python/unittest/test_runtime_module_export.py
@@ -58,7 +58,7 @@ def generate_engine_module():
     import tvm.runtime._ffi_api
 
     gen_engine_header()
-    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", "", None)
+    csource_module = tvm.runtime._ffi_api.CSourceModuleCreate(code, "cc", [], None)
     return csource_module
 
 
