This is an automated email from the ASF dual-hosted git repository.

mousius pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5a58c581f5 Added macro generation in MLF export (#12789)
5a58c581f5 is described below

commit 5a58c581f5e0272a42a5b68ed78c400138fc0082
Author: fPecc <[email protected]>
AuthorDate: Wed Dec 7 12:11:00 2022 +0100

    Added macro generation in MLF export (#12789)
    
    The generated MLF header files for each module contain the struct 
definitions used as the inputs and outputs when calling the generated function. 
To call tvmgen_default_run, we need to allocate space (statically or 
dynamically) for the input and output tensors. This change generates macros 
that define the size of each input and output in bytes, which allows us to 
reference these new macros to statically or dynamically allocate vectors that 
store the inputs and outputs of the tvmgen_defaul [...]
    
    
    Co-authored-by: Federico Peccia <[email protected]>
    Co-authored-by: Christopher Sidebottom <[email protected]>
---
 python/tvm/micro/model_library_format.py           |  68 ++++-
 src/target/source/interface_c.cc                   |  34 ++-
 tests/cpp/target/source/interface_c_test.cc        | 316 +++++++++++++++++----
 tests/micro/zephyr/utils.py                        |  11 +-
 .../unittest/test_micro_model_library_format.py    |  59 ++--
 5 files changed, 412 insertions(+), 76 deletions(-)

diff --git a/python/tvm/micro/model_library_format.py 
b/python/tvm/micro/model_library_format.py
index 1ba9f5e733..5aa2d154ba 100644
--- a/python/tvm/micro/model_library_format.py
+++ b/python/tvm/micro/model_library_format.py
@@ -47,7 +47,16 @@ class UnsupportedInModelLibraryFormatError(Exception):
 
 
 def generate_c_interface_header(
-    module_name, inputs, outputs, pools, io_pool_allocations, devices, 
workspace_size, include_path
+    module_name,
+    inputs,
+    outputs,
+    pools,
+    io_pool_allocations,
+    devices,
+    workspace_size,
+    include_path,
+    input_sizes,
+    output_sizes,
 ):
     """Generate C Interface header to be included in MLF"""
     mangled_name = to_c_variable_style(prefix_generated_name(module_name))
@@ -55,7 +64,15 @@ def generate_c_interface_header(
 
     interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate")
     interface_c_module = interface_c_create(
-        module_name, inputs, outputs, pools, io_pool_allocations, devices, 
workspace_size
+        module_name,
+        inputs,
+        outputs,
+        pools,
+        io_pool_allocations,
+        devices,
+        workspace_size,
+        input_sizes,
+        output_sizes,
     )
 
     with open(metadata_header, "w") as header_file:
@@ -193,6 +210,13 @@ def _build_sid_map(graph_json):
     return memory_map
 
 
+def _create_type_metadata(input_type):
+    return {
+        "size": int(_shape_to_size(input_type.shape, input_type.dtype)),
+        "dtype": str(input_type.dtype),
+    }
+
+
 def _build_function_memory_map(function_metadata):
     """Build a simple map that shows how much workspace is required to execute
     each primitive function. The main_func describes how much memory is 
required
@@ -277,6 +301,26 @@ def _build_function_memory_map(function_metadata):
             main_func_metadata.io_sizes[target]
         )
 
+        # Now, we also add the information about the size of each input and 
output of the main
+        # function (in bytes)
+        input_dict = {}
+        for input_param in main_func_metadata.relay_primfuncs[target].params:
+            input_dict[input_param.name_hint] = 
_create_type_metadata(input_param.checked_type)
+        target_main_entries[int(target.get_target_device_type())]["inputs"] = 
input_dict
+
+        output_dict = {}
+        # For output, we dont have the name of the output, so we enumerate them
+        if isinstance(main_func_metadata.relay_primfuncs[target].ret_type, 
tvm.ir.type.TupleType):
+            output_list = _convert_tuple_to_outputs(
+                main_func_metadata.relay_primfuncs[target].ret_type
+            )
+            for i, output_type in enumerate(output_list):
+                output_dict[f"output{i}"] = _create_type_metadata(output_type)
+        else:
+            output_type = main_func_metadata.relay_primfuncs[target].ret_type
+            output_dict["output"] = _create_type_metadata(output_type)
+        target_main_entries[int(target.get_target_device_type())]["outputs"] = 
output_dict
+
     ret = {
         "operator_functions": func_entries,
         "main": list(target_main_entries.values()),
@@ -298,7 +342,7 @@ def _convert_tuple_to_outputs(ret_type, offset=0):
         if isinstance(ret_type.fields[output_index], TupleType):
             
outputs.extend(_convert_tuple_to_outputs(ret_type.fields[output_index], 
next_output))
         else:
-            outputs.append(f"output{next_output}")
+            outputs.append(ret_type.fields[output_index])
     return outputs
 
 
@@ -427,6 +471,20 @@ def _export_graph_model_library_format(
                     "workspace_size_bytes"
                 ]
             )
+            inputs_sizes = 
metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][
+                "inputs"
+            ]
+            # Here, we merge the output sizes with the actual output names
+            output_sizes = {}
+            for i, key in enumerate(
+                
metadata["modules"][mod.libmod_name]["memory"]["functions"]["main"][0][
+                    "outputs"
+                ].keys()
+            ):
+                output_sizes[outputs[i]] = 
metadata["modules"][mod.libmod_name]["memory"][
+                    "functions"
+                ]["main"][0]["outputs"][key]
+
             generate_c_interface_header(
                 mod.libmod_name,
                 inputs,
@@ -436,6 +494,8 @@ def _export_graph_model_library_format(
                 devices,
                 workspace_size,
                 include_path,
+                inputs_sizes,
+                output_sizes,
             )
 
         is_aot = isinstance(mod, executor_factory.AOTExecutorFactoryModule)
@@ -459,7 +519,7 @@ class NonStaticShapeError(Exception):
 
 def _shape_to_size(shape, dtype):
     bits_per_item = int(
-        re.match(r"((float)|(int))(?P<width_bits>[0-9]+)", 
dtype).group("width_bits")
+        re.match(r"((float)|(int)|(uint))(?P<width_bits>[0-9]+)", 
dtype).group("width_bits")
     )
     assert bits_per_item is not None, f"don't know how to compute size of type 
{dtype}"
     total_bits = bits_per_item
diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc
index ed7058f1f1..fe495b212a 100644
--- a/src/target/source/interface_c.cc
+++ b/src/target/source/interface_c.cc
@@ -47,20 +47,42 @@ class InterfaceCNode : public runtime::ModuleNode {
   InterfaceCNode(std::string module_name, Array<String> inputs, Array<String> 
outputs,
                  Array<tir::usmp::AllocatedPoolInfo> pools,
                  Map<String, tir::usmp::PoolAllocation> io_pool_allocations, 
Array<String> devices,
-                 int workspace_size)
+                 int workspace_size, Map<String, IntImm> input_sizes,
+                 Map<String, IntImm> output_sizes)
       : module_name_(module_name),
         inputs_(inputs),
         outputs_(outputs),
         devices_(devices),
         pools_(FilterExternalPools(pools)),
         io_pool_allocations_(io_pool_allocations),
-        workspace_size_(workspace_size) {}
+        workspace_size_(workspace_size),
+        input_sizes_(input_sizes),
+        output_sizes_(output_sizes) {}
   const char* type_key() const final { return "h"; }
 
   std::string GetSource(const std::string& format) final {
     std::stringstream code;
 
     EmitUpperHeaderGuard(code);
+
+    // Emit macros for input sizes
+    for (auto const& it : input_sizes_) {
+      std::string input_name = SanitizeName(it.first);
+      std::string input_macro_name = input_name + "_size";
+      int input_size = it.second->value;
+      EmitIntegerValueMacro(code, "Input tensor " + input_name + " size (in 
bytes)",
+                            input_macro_name, input_size);
+    }
+
+    // Emit macros for output sizes
+    for (auto const& it : output_sizes_) {
+      std::string output_name = SanitizeName(it.first);
+      std::string output_macro_name = output_name + "_size";
+      int output_size = it.second->value;
+      EmitIntegerValueMacro(code, "Output tensor " + output_name + " size (in 
bytes)",
+                            output_macro_name, output_size);
+    }
+
     EmitBrief(code, "Input tensor pointers");
     EmitStruct(code, "inputs", inputs_);
     EmitBrief(code, "Output tensor pointers");
@@ -278,14 +300,18 @@ class InterfaceCNode : public runtime::ModuleNode {
   Array<tir::usmp::AllocatedPoolInfo> pools_;
   Map<String, tir::usmp::PoolAllocation> io_pool_allocations_;
   int workspace_size_;
+  Map<String, IntImm> input_sizes_;
+  Map<String, IntImm> output_sizes_;
 };
 
 runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
                                  Array<String> outputs, 
Array<tir::usmp::AllocatedPoolInfo> pools,
                                  Map<String, tir::usmp::PoolAllocation> 
io_pool_allocations,
-                                 Array<String> devices, int workspace_size) {
+                                 Array<String> devices, int workspace_size,
+                                 Map<String, IntImm> input_sizes,
+                                 Map<String, IntImm> output_sizes) {
   auto n = make_object<InterfaceCNode>(module_name, inputs, outputs, pools, 
io_pool_allocations,
-                                       devices, workspace_size);
+                                       devices, workspace_size, input_sizes, 
output_sizes);
   return runtime::Module(n);
 }
 
diff --git a/tests/cpp/target/source/interface_c_test.cc 
b/tests/cpp/target/source/interface_c_test.cc
index d575bfeaf0..d9d9d80bbe 100644
--- a/tests/cpp/target/source/interface_c_test.cc
+++ b/tests/cpp/target/source/interface_c_test.cc
@@ -33,7 +33,8 @@ namespace codegen {
 runtime::Module InterfaceCCreate(std::string module_name, Array<String> inputs,
                                  Array<String> outputs, 
Array<tir::usmp::AllocatedPoolInfo> pools,
                                  Map<String, tir::usmp::PoolAllocation> 
io_pool_allocations,
-                                 Array<String> devices, int workspace_size);
+                                 Array<String> devices, int workspace_size,
+                                 Map<String, IntImm> input_sizes, Map<String, 
IntImm> output_sizes);
 
 namespace {
 
@@ -53,8 +54,13 @@ TEST(InterfaceAPI, ContainsHeaderGuards) {
                      << "#endif\n\n"
                      << "#endif // TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 
{}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
+                                                 {}, {}, 0, input_sizes, 
output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(upper_header_guard.str()));
@@ -74,8 +80,13 @@ TEST(InterfaceAPI, ContainsRunFunction) {
                << "  struct tvmgen_ultimate_cat_spotter_outputs* outputs\n"
                << ");\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 
{}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
+                                                 {}, {}, 0, input_sizes, 
output_sizes);
   std::string header_source = test_module->GetSource();
   ASSERT_THAT(header_source, HasSubstr(run_function.str()));
 }
@@ -95,8 +106,13 @@ TEST(InterfaceAPI, ContainsRunFunctionWithDevices) {
                << "  struct tvmgen_ultimate_cat_spotter_devices* devices\n"
                << ");\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 
{"device"}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
+                                                 {}, {"device"}, 0, 
input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(run_function.str()));
@@ -117,11 +133,17 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePools) 
{
                << "  struct tvmgen_ultimate_cat_spotter_workspace_pools* 
workspace_pools\n"
                << ");\n";
 
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
   PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {});
   tir::usmp::AllocatedPoolInfo allocated_pool_info =
       tir::usmp::AllocatedPoolInfo(pool_info, 100000);
-  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"},
-                                                 {allocated_pool_info}, {}, 
{}, 0);
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{allocated_pool_info}, {}, {},
+                       0, input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(run_function.str()));
@@ -142,6 +164,11 @@ TEST(InterfaceAPI, 
ContainsRunFunctionWithWorkspaceAndConstantPools) {
                << "  struct tvmgen_ultimate_cat_spotter_workspace_pools* 
workspace_pools\n"
                << ");\n";
 
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
   PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {});
   PoolInfo const_info = ConstantPoolInfo(
       "my_constant_pool", {},
@@ -151,9 +178,9 @@ TEST(InterfaceAPI, 
ContainsRunFunctionWithWorkspaceAndConstantPools) {
       tir::usmp::AllocatedPoolInfo(pool_info, 100000);
   tir::usmp::AllocatedPoolInfo allocated_const_info =
       tir::usmp::AllocatedPoolInfo(const_info, 100000);
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
-                       {allocated_pool_info, allocated_const_info}, {}, {}, 0);
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"},
+                                                 {allocated_pool_info, 
allocated_const_info}, {},
+                                                 {}, 0, input_sizes, 
output_sizes);
   std::string header_source = test_module->GetSource();
   ASSERT_THAT(header_source, HasSubstr(run_function.str()));
   ASSERT_THAT(
@@ -186,11 +213,17 @@ TEST(InterfaceAPI, 
ContainsRunFunctionWithWorkspacePoolsAndDevices) {
                << "  struct tvmgen_ultimate_cat_spotter_devices* devices\n"
                << ");\n";
 
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
   PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {});
   tir::usmp::AllocatedPoolInfo allocated_pool_info =
       tir::usmp::AllocatedPoolInfo(pool_info, 100000);
-  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"},
-                                                 {allocated_pool_info}, {}, 
{"device"}, 0);
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{allocated_pool_info}, {},
+                       {"device"}, 0, input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(run_function.str()));
@@ -226,14 +259,20 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspaceIO) {
       << "  struct tvmgen_ultimate_cat_spotter_workspace_pools* 
workspace_pools\n"
       << ");\n";
 
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
   PoolInfo pool_info = WorkspacePoolInfo("my_memory_pool", {});
   tir::usmp::AllocatedPoolInfo allocated_pool_info =
       tir::usmp::AllocatedPoolInfo(pool_info, 100000);
   tir::usmp::PoolAllocation pool_allocation_input{pool_info, 1000};
   tir::usmp::PoolAllocation pool_allocation_output{pool_info, 2000};
-  runtime::Module test_module = InterfaceCCreate(
-      "ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info},
-      {{"input", pool_allocation_input}, {"output", pool_allocation_output}}, 
{}, 0);
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{allocated_pool_info},
+                       {{"input", pool_allocation_input}, {"output", 
pool_allocation_output}}, {},
+                       0, input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
   std::cout << header_source << "\n";
   ASSERT_THAT(header_source, HasSubstr(run_function_with_map_functions.str()));
@@ -241,6 +280,13 @@ TEST(InterfaceAPI, ContainsRunFunctionWithWorkspaceIO) {
 
 TEST(InterfaceAPI, ContainsInputStructSingle) {
   std::stringstream input_struct;
+  std::stringstream input_size_macro;
+
+  input_size_macro
+      << "/*!\n"
+      << " * \\brief Input tensor input size (in bytes) for TVM module 
\"ultimate_cat_spotter\" \n"
+      << " */\n"
+      << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_INPUT_SIZE 537\n";
 
   input_struct << "/*!\n"
                << " * \\brief Input tensor pointers for TVM module 
\"ultimate_cat_spotter\" \n"
@@ -249,51 +295,120 @@ TEST(InterfaceAPI, ContainsInputStructSingle) {
                << "  void* input;\n"
                << "};\n\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 
{}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 537));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
+                                                 {}, {}, 0, input_sizes, 
output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
+
+  ASSERT_THAT(header_source, HasSubstr(input_size_macro.str()));
 }
 
 TEST(InterfaceAPI, ContainsInputStructMany) {
   std::stringstream input_struct;
+  std::stringstream input1_size_macro;
+  std::stringstream input2_size_macro;
+
+  input1_size_macro
+      << "/*!\n"
+      << " * \\brief Input tensor input1 size (in bytes) for TVM module 
\"ultimate_cat_spotter\" \n"
+      << " */\n"
+      << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_INPUT1_SIZE 765\n";
+
+  input2_size_macro
+      << "/*!\n"
+      << " * \\brief Input tensor input2 size (in bytes) for TVM module 
\"ultimate_cat_spotter\" \n"
+      << " */\n"
+      << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_INPUT2_SIZE 127\n";
 
   input_struct << "struct tvmgen_ultimate_cat_spotter_inputs {\n"
                << "  void* input1;\n"
                << "  void* input2;\n"
                << "};\n\n";
 
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input1", IntImm(DataType::Int(32), 765));
+  input_sizes.Set("input2", IntImm(DataType::Int(32), 127));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, 
{"output"}, {}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, 
{"output"}, {}, {}, {}, 0,
+                       input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
+  ASSERT_THAT(header_source, HasSubstr(input1_size_macro.str()));
+  ASSERT_THAT(header_source, HasSubstr(input2_size_macro.str()));
 }
 
 TEST(InterfaceAPI, ContainsInputStructSanitised) {
   std::stringstream input_struct;
+  std::stringstream input1_size_macro;
+  std::stringstream input2_size_macro;
+
+  input1_size_macro << "/*!\n"
+                    << " * \\brief Input tensor input_1 size (in bytes) for 
TVM module "
+                       "\"ultimate_cat_spotter\" \n"
+                    << " */\n"
+                    << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_INPUT_1_SIZE 
765\n";
+
+  input2_size_macro << "/*!\n"
+                    << " * \\brief Input tensor input_2 size (in bytes) for 
TVM module "
+                       "\"ultimate_cat_spotter\" \n"
+                    << " */\n"
+                    << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_INPUT_2_SIZE 
127\n";
 
   input_struct << "struct tvmgen_ultimate_cat_spotter_inputs {\n"
                << "  void* input_1;\n"
                << "  void* input_2;\n"
                << "};\n\n";
 
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input+1", IntImm(DataType::Int(32), 765));
+  input_sizes.Set("input+2", IntImm(DataType::Int(32), 127));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, 
{"output"}, {}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, 
{"output"}, {}, {}, {}, 0,
+                       input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
+  std::cout << header_source << std::endl;
+
   ASSERT_THAT(header_source, HasSubstr(input_struct.str()));
+  ASSERT_THAT(header_source, HasSubstr(input1_size_macro.str()));
+  ASSERT_THAT(header_source, HasSubstr(input2_size_macro.str()));
 }
 
 TEST(InterfaceAPI, ContainsInputStructClash) {
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input+", IntImm(DataType::Int(32), 0));
+  input_sizes.Set("input-", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, 
{"output"}, {}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, 
{"output"}, {}, {}, {}, 0,
+                       input_sizes, output_sizes);
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
 TEST(InterfaceAPI, ContainsOutputStructSingle) {
   std::stringstream output_struct;
+  std::stringstream output_size_macro;
+
+  output_size_macro << "/*!\n"
+                    << " * \\brief Output tensor output size (in bytes) for 
TVM module "
+                       "\"ultimate_cat_spotter\" \n"
+                    << " */\n"
+                    << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_OUTPUT_SIZE 543\n";
 
   output_struct << "/*!\n"
                 << " * \\brief Output tensor pointers for TVM module 
\"ultimate_cat_spotter\" \n"
@@ -302,46 +417,104 @@ TEST(InterfaceAPI, ContainsOutputStructSingle) {
                 << "  void* output;\n"
                 << "};\n\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 
{}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 543));
+
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
+                                                 {}, {}, 0, input_sizes, 
output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
+  ASSERT_THAT(header_source, HasSubstr(output_size_macro.str()));
 }
 
 TEST(InterfaceAPI, ContainsOutputStructMany) {
   std::stringstream output_struct;
+  std::stringstream output1_size_macro;
+  std::stringstream output2_size_macro;
+
+  output1_size_macro << "/*!\n"
+                     << " * \\brief Output tensor output1 size (in bytes) for 
TVM module "
+                        "\"ultimate_cat_spotter\" \n"
+                     << " */\n"
+                     << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_OUTPUT1_SIZE 
345\n";
+
+  output2_size_macro << "/*!\n"
+                     << " * \\brief Output tensor output2 size (in bytes) for 
TVM module "
+                        "\"ultimate_cat_spotter\" \n"
+                     << " */\n"
+                     << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_OUTPUT2_SIZE 
984\n";
 
   output_struct << "struct tvmgen_ultimate_cat_spotter_outputs {\n"
                 << "  void* output1;\n"
                 << "  void* output2;\n"
                 << "};\n\n";
 
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output1", IntImm(DataType::Int(32), 345));
+  output_sizes.Set("output2", IntImm(DataType::Int(32), 984));
+
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", 
"output2"}, {}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", 
"output2"}, {}, {}, {}, 0,
+                       input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
+  ASSERT_THAT(header_source, HasSubstr(output1_size_macro.str()));
+  ASSERT_THAT(header_source, HasSubstr(output2_size_macro.str()));
 }
 
 TEST(InterfaceAPI, ContainsOutputStructSanitised) {
   std::stringstream output_struct;
+  std::stringstream output1_size_macro;
+  std::stringstream output2_size_macro;
+
+  output1_size_macro << "/*!\n"
+                     << " * \\brief Output tensor output_1 size (in bytes) for 
TVM module "
+                        "\"ultimate_cat_spotter\" \n"
+                     << " */\n"
+                     << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_OUTPUT_1_SIZE 
345\n";
+
+  output2_size_macro << "/*!\n"
+                     << " * \\brief Output tensor output_2 size (in bytes) for 
TVM module "
+                        "\"ultimate_cat_spotter\" \n"
+                     << " */\n"
+                     << "#define TVMGEN_ULTIMATE_CAT_SPOTTER_OUTPUT_2_SIZE 
984\n";
 
   output_struct << "struct tvmgen_ultimate_cat_spotter_outputs {\n"
                 << "  void* output_1;\n"
                 << "  void* output_2;\n"
                 << "};\n\n";
 
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output+1", IntImm(DataType::Int(32), 345));
+  output_sizes.Set("output-2", IntImm(DataType::Int(32), 984));
+
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", 
"output-2"}, {}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", 
"output-2"}, {}, {}, {}, 0,
+                       input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(output_struct.str()));
+  ASSERT_THAT(header_source, HasSubstr(output1_size_macro.str()));
+  ASSERT_THAT(header_source, HasSubstr(output2_size_macro.str()));
 }
 
 TEST(InterfaceAPI, ContainsOutputStructClash) {
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output+", IntImm(DataType::Int(32), 0));
+  output_sizes.Set("output-", IntImm(DataType::Int(32), 0));
   runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", 
"output-"}, {}, {}, {}, 0);
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", 
"output-"}, {}, {}, {}, 0,
+                       input_sizes, output_sizes);
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
@@ -354,8 +527,12 @@ TEST(InterfaceAPI, NoDeviceAPIStructIfNoDevices) {
                 << "struct tvmgen_ultimate_cat_spotter_devices {\n"
                 << "};\n\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 
{}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
+                                                 {}, {}, 0, input_sizes, 
output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, Not(HasSubstr(device_struct.str())));
@@ -371,8 +548,12 @@ TEST(InterfaceAPI, ContainsDeviceStructSingle) {
                 << "  void* device;\n"
                 << "};\n\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 
{"device"}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
+                                                 {}, {"device"}, 0, 
input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
@@ -386,8 +567,13 @@ TEST(InterfaceAPI, ContainsDeviceStructMany) {
                 << "  void* device2;\n"
                 << "};\n\n";
 
-  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
-                                                 {}, {"device1", "device2"}, 
0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {},
+                       {"device1", "device2"}, 0, input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
@@ -401,22 +587,36 @@ TEST(InterfaceAPI, ContainsDeviceStructSanitised) {
                 << "  void* device_2;\n"
                 << "};\n\n";
 
-  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
-                                                 {}, {"device+1", "device+2"}, 
0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {},
+                       {"device+1", "device+2"}, 0, input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(device_struct.str()));
 }
 
 TEST(InterfaceAPI, ContainsDeviceStructClash) {
-  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
-                                                 {}, {"device+", "device-"}, 
0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {},
+                       {"device+", "device-"}, 0, input_sizes, output_sizes);
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
 TEST(InterfaceAPI, ContainsWorkspaceSize) {
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 
{}, 765432);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"}, {},
+                                                 {}, {}, 765432, input_sizes, 
output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source,
@@ -441,8 +641,13 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSingle) {
       << "  void* my_memory_pool;\n"
       << "};\n\n";
 
-  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"},
-                                                 {allocated_pool_info}, {}, 
{}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{allocated_pool_info}, {}, {},
+                       0, input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(workspace_struct.str()));
@@ -474,9 +679,13 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructMany) {
       << "  void* my_memory_pool_2;\n"
       << "};\n\n";
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
-                       {allocated_pool_info1, allocated_pool_info2}, {}, {}, 
0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"},
+                                                 {allocated_pool_info1, 
allocated_pool_info2}, {},
+                                                 {}, 0, input_sizes, 
output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(workspace_struct.str()));
@@ -511,8 +720,13 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructSanitized) {
       << "  void* my_memory_pool_1;\n"
       << "};\n\n";
 
-  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"},
-                                                 {allocated_pool_info}, {}, 
{}, 0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module =
+      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{allocated_pool_info}, {}, {},
+                       0, input_sizes, output_sizes);
   std::string header_source = test_module->GetSource();
 
   ASSERT_THAT(header_source, HasSubstr(workspace_struct.str()));
@@ -533,9 +747,13 @@ TEST(InterfaceAPI, ContainsWorkspacePoolStructClash) {
   tir::usmp::AllocatedPoolInfo allocated_pool_info2 =
       tir::usmp::AllocatedPoolInfo(pool_info2, 200000);
 
-  runtime::Module test_module =
-      InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"},
-                       {allocated_pool_info1, allocated_pool_info2}, {}, {}, 
0);
+  Map<String, IntImm> input_sizes;
+  input_sizes.Set("input", IntImm(DataType::Int(32), 0));
+  Map<String, IntImm> output_sizes;
+  output_sizes.Set("output", IntImm(DataType::Int(32), 0));
+  runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", 
{"input"}, {"output"},
+                                                 {allocated_pool_info1, 
allocated_pool_info2}, {},
+                                                 {}, 0, input_sizes, 
output_sizes);
   ASSERT_THROW(test_module->GetSource(), InternalError);
 }
 
diff --git a/tests/micro/zephyr/utils.py b/tests/micro/zephyr/utils.py
index 05b2090944..42419b637f 100644
--- a/tests/micro/zephyr/utils.py
+++ b/tests/micro/zephyr/utils.py
@@ -218,7 +218,16 @@ def generate_project(
                         model_files_path, 
arcname=os.path.relpath(model_files_path, tar_temp_dir)
                     )
                 header_path = generate_c_interface_header(
-                    lowered.libmod_name, ["input_1"], ["Identity"], [], {}, 
[], 0, model_files_path
+                    lowered.libmod_name,
+                    ["input_1"],
+                    ["Identity"],
+                    [],
+                    {},
+                    [],
+                    0,
+                    model_files_path,
+                    {},
+                    {},
                 )
                 tf.add(header_path, arcname=os.path.relpath(header_path, 
tar_temp_dir))
 
diff --git a/tests/python/unittest/test_micro_model_library_format.py 
b/tests/python/unittest/test_micro_model_library_format.py
index 9b957e617a..7ccaf72b1b 100644
--- a/tests/python/unittest/test_micro_model_library_format.py
+++ b/tests/python/unittest/test_micro_model_library_format.py
@@ -208,7 +208,12 @@ def test_export_model_library_format_c(
                 {
                     "constants_size_bytes": json_constants_size_bytes,
                     "device": 1,
+                    "inputs": {
+                        "a": {"dtype": "uint8", "size": 2},
+                        "b": {"dtype": "float32", "size": 8},
+                    },
                     "io_size_bytes": 18,
+                    "outputs": {"output": {"dtype": "float32", "size": 8}},
                     "workspace_size_bytes": 0,
                 }
             ]
@@ -295,7 +300,12 @@ def test_export_model_library_format_llvm():
                 {
                     "constants_size_bytes": 8,
                     "device": 1,
+                    "inputs": {
+                        "a": {"dtype": "uint8", "size": 2},
+                        "b": {"dtype": "float32", "size": 8},
+                    },
                     "io_size_bytes": 18,
+                    "outputs": {"output": {"dtype": "float32", "size": 8}},
                     "workspace_size_bytes": 0,
                 }
             ]
@@ -373,7 +383,13 @@ def test_export_model_library_format_workspace(executor, 
runtime):
             {
                 "constants_size_bytes": 0,
                 "device": 1,
+                "inputs": {
+                    "p0": {"dtype": "int16", "size": 802816},
+                    "p1": {"dtype": "int16", "size": 2304},
+                    "p2": {"dtype": "int32", "size": 512},
+                },
                 "io_size_bytes": 1207040,
+                "outputs": {"output": {"dtype": "uint8", "size": 401408}},
                 "workspace_size_bytes": 2466816,
             }
         ]
@@ -454,24 +470,26 @@ def test_export_byoc_c_module():
         with tf.extractfile("./metadata.json") as f:
             metadata = json.load(f)
         main_md = 
metadata["modules"][factory.libmod_name]["memory"]["functions"]["main"]
-        if platform.architecture()[0] == "64bit":
-            assert main_md == [
-                {
-                    "constants_size_bytes": 0,
-                    "device": 1,
-                    "io_size_bytes": 4800,
-                    "workspace_size_bytes": 1200,
-                }
-            ]
-        else:
-            assert main_md == [
-                {
-                    "constants_size_bytes": 0,
-                    "device": 1,
-                    "io_size_bytes": 4800,
-                    "workspace_size_bytes": 1200,
-                }
-            ]
+        assert main_md == [
+            {
+                "constants_size_bytes": 0,
+                "device": 1,
+                "inputs": {
+                    "w0": {"dtype": "float32", "size": 400},
+                    "w1": {"dtype": "float32", "size": 400},
+                    "w2": {"dtype": "float32", "size": 400},
+                    "w3": {"dtype": "float32", "size": 400},
+                    "w4": {"dtype": "float32", "size": 400},
+                    "w5": {"dtype": "float32", "size": 400},
+                    "w6": {"dtype": "float32", "size": 400},
+                    "w7": {"dtype": "float32", "size": 400},
+                    "x": {"dtype": "float32", "size": 400},
+                },
+                "io_size_bytes": 4800,
+                "outputs": {"output": {"dtype": "float32", "size": 1200}},
+                "workspace_size_bytes": 1200,
+            }
+        ]
 
 
 @tvm.testing.requires_micro
@@ -523,7 +541,12 @@ def test_multiple_relay_modules_graph():
             {
                 "constants_size_bytes": 0,
                 "device": 1,
+                "inputs": {
+                    "data": {"dtype": "int8", "size": 12288},
+                    "weight": {"dtype": "int8", "size": 600},
+                },
                 "io_size_bytes": 143960,
+                "outputs": {"output": {"dtype": "int32", "size": 131072}},
                 "workspace_size_bytes": 158088,
             }
         ]


Reply via email to