This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 698531e6d7 [CodeGenC][Redo] Handle GlobalVar callee as internal function call (#15835)
698531e6d7 is described below

commit 698531e6d7e1493b7a73c71137132de87de8aad0
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Wed Oct 18 14:07:43 2023 -0500

    [CodeGenC][Redo] Handle GlobalVar callee as internal function call (#15835)
    
    * [CodeGenC][Redo] Handle GlobalVar callee as internal function call
    
    This reverts commit
    [`e88d0d`](https://github.com/apache/tvm/pull/15725), which
    itself reverted
    [`9ff71f`](https://github.com/apache/tvm/pull/15103) for
    breakages on the metal backend.  Now that the CI contains
    compile-time testing of the metal codegen, the original
    breakage should be identifiable.
    
    * Added codegen metal CI debug print
    
    * Print function decl to the argument stream
    
    * Remove the codegen metal CI debug print-outs
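
    For context, the user-visible effect of this change: a PrimFunc may
    now call another PrimFunc of the same IRModule through its GlobalVar,
    where CodeGenC previously failed with "Do not yet support cross
    function call".  The callee is forward-declared and the call lowers
    to an ordinary C call.  A minimal sketch of the emitted C, with
    illustrative names that are not taken from this commit:

        // collected in fwd_decl_stream by DeclareFunction()
        int32_t subroutine(float* A, int32_t n);

        // collected in stream by AddFunction()
        int32_t main_func(float* A, int32_t n) {
          // a CallNode whose op is a GlobalVar lowers to a plain call
          return subroutine(A, n);
        }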
---
 .../arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py  |   8 +-
 .../topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py |  87 +++++++++---
 .../arm_cpu/mprofile/dsp/micro_kernel/max_pool.py  |  13 +-
 .../arm_cpu/mprofile/dsp/micro_kernel/tensordot.py |   7 +-
 .../backend/contrib/cmsisnn/tir_to_runtime.cc      |  28 ++--
 .../contrib/example_target_hooks/tir_to_runtime.cc |  26 +++-
 src/relay/backend/contrib/uma/tir_to_runtime.cc    |  34 +++--
 src/target/opt/build_cuda_on.cc                    |  18 ++-
 src/target/source/codegen_aocl.cc                  |  19 ++-
 src/target/source/codegen_c.cc                     | 153 ++++++++++++++-------
 src/target/source/codegen_c.h                      |  59 +++++++-
 src/target/source/codegen_c_host.cc                |  93 ++++++-------
 src/target/source/codegen_c_host.h                 |   3 +-
 src/target/source/codegen_cuda.cc                  |   4 +-
 src/target/source/codegen_cuda.h                   |   2 +-
 src/target/source/codegen_metal.cc                 |  89 ++++++------
 src/target/source/codegen_metal.h                  |   3 +-
 src/target/source/codegen_opencl.cc                |  24 ++--
 src/target/source/codegen_vhls.cc                  |  34 +++--
 src/target/source/codegen_webgpu.cc                |  79 +++++------
 src/target/source/codegen_webgpu.h                 |   4 +-
 src/target/source/source_module.cc                 |   6 +-
 src/tir/op/op.cc                                   |  26 ++++
 .../relay/aot/test_crt_forward_declarations.py     |   6 +-
 .../topi/python/test_topi_conv2d_tensordot_opts.py |  28 +++-
 .../python/unittest/test_target_codegen_c_host.py  |  48 +++++--
 .../test_tir_transform_inject_ptx_async_copy.py    |   1 +
 27 files changed, 597 insertions(+), 305 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py 
b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
index e8e45152aa..3eb32d8fdb 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py
@@ -55,7 +55,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    cc.dtype,
+                    "int32",
                     f"{func_prefix}_{width}_{uniq_id}",
                     aa.access_ptr("r"),
                     cc.access_ptr("w"),
@@ -68,7 +68,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False):
         def _reduce_reset():
             ib = tvm.tir.ir_builder.create()
             ib.emit(
-                tvm.tir.call_extern(cc.dtype, 
f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"))
+                tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", 
cc.access_ptr("w"))
             )
             return ib.get()
 
@@ -113,8 +113,8 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}(
     int16_t *arr,
     int16_t *res16,
-    long arr_offset,
-    int reset) {{
+    int32_t arr_offset,
+    int32_t reset) {{
   int n;
   int32_t *p32;
   int32_t res = reset ? 0 : *res16;
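
The signature changes in this file (and the matching ones in gemm.py,
max_pool.py, and tensordot.py below) line up with the codegen change further
down: CodeGenC now emits a forward declaration for each call_extern callee,
deriving the argument and return types from the TIR call site, and TIR's
"int32" maps to the fixed-width int32_t.  A kernel defined with plain `int`
or `long`, or invoked with a return dtype other than the int32_t it actually
returns, could then conflict with the generated declaration.  A short sketch
with a hypothetical kernel name:

    /* forward declaration derived from the TIR call site: */
    int32_t sum16_kernel(int16_t *arr, int16_t *res,
                         int32_t arr_offset, int32_t reset);

    /* a definition like the old one clashes wherever long != int32_t:
    int32_t sum16_kernel(int16_t *arr, int16_t *res,
                         long arr_offset, int reset);           */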
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py 
b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
index 929dcc6557..e26e818fbd 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py
@@ -156,9 +156,14 @@ __attribute__((always_inline)) static inline const int8_t 
*read_and_pad(const in
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm_{M}x{N}_body_rest_{uniq_id}(
-    int K,
+    int32_t K_arg,
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int K = K_arg;
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int k_base = (K / 4) * 4;
   switch ( K % 4 ) {{
   case 1:
@@ -200,7 +205,12 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm_{M}x{K}x{N}_body_loop_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -221,7 +231,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm_{M}x{K}x{N}_body_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int16_t bb_pad[{bb_pad_size}];
   int32_t retcode = 0;
 
@@ -265,9 +279,14 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm_{M}x{N}_update_rest_{uniq_id}(
-    int K,
+    int32_t K_arg,
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int K = K_arg;
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int k_base = (K / 4) * 4;
   switch ( K % 4 ) {{
   case 1:
@@ -309,7 +328,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm_{M}x{K}x{N}_update_loop_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -327,7 +350,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm_{M}x{K}x{N}_update_{uniq_id}(
     int8_t *aa, int8_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int16_t bb_pad[{bb_pad_size}];
   int32_t retcode = 0;
 
@@ -368,9 +395,14 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm16_{M}x{N}_body_rest_{uniq_id}(
-    int K,
+    int32_t K_arg,
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int K = K_arg;
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int k_base = (K / 2) * 2;
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
@@ -387,7 +419,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm16_{M}x{K}x{N}_body_loop_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -408,7 +444,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm16_{M}x{K}x{N}_body_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int32_t retcode = 0;
 
   if ( {M} < 2 && {N} < 2 ) {{
@@ -450,9 +490,14 @@ out:
 extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm16_{M}x{N}_update_rest_{uniq_id}(
-    int K,
+    int32_t K_arg,
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int K = K_arg;
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int k_base = (K / 2) * 2;
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
@@ -469,7 +514,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm16_{M}x{K}x{N}_update_loop_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       int32_t sum = 0;
@@ -487,7 +536,11 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t 
gemm16_{M}x{K}x{N}_update_{uniq_id}(
     int16_t *aa, int16_t *bb, int32_t *cc,
-    int A_stride, int B_stride, int C_stride) {{
+    int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{
+  int A_stride = A_stride_arg;
+  int B_stride = B_stride_arg;
+  int C_stride = C_stride_arg;
+
   int32_t retcode = 0;
 
   if ( {M} < 2 && {N} < 2 ) {{
@@ -520,7 +573,7 @@ out:
 #ifdef __cplusplus
 extern "C"
 #endif
-__attribute__((always_inline)) static inline int32_t 
gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{
+__attribute__((always_inline)) static inline int32_t 
gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{
   for (int i = 0; i < {M}; i++) {{
     for (int j = 0; j < {N}; j++) {{
       cc[i*C_stride + j] = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py 
b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
index 66d712a4a0..cfed417c9f 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py
@@ -46,7 +46,7 @@ def intrin_max(shape, in_dtype, out_dtype):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    cc.dtype,
+                    "int32",
                     f"{func_prefix}_{uniq_id}",
                     aa.access_ptr("r"),
                     cc.access_ptr("w"),
@@ -59,7 +59,7 @@ def intrin_max(shape, in_dtype, out_dtype):
             ib = tvm.tir.ir_builder.create()
             ib.emit(
                 tvm.tir.call_extern(
-                    cc.dtype, f"{func_prefix}_reset_{uniq_id}", 
cc.access_ptr("w"), cc.strides[0]
+                    "int32", f"{func_prefix}_reset_{uniq_id}", 
cc.access_ptr("w"), cc.strides[0]
                 )
             )
             return ib.get()
@@ -96,7 +96,7 @@ extern "C"
 #endif
 __attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}(
     int8_t *res,
-    int N) {{
+    int32_t N) {{
   memset(res, (int8_t)-128, N * sizeof(*res));
   return 0;
 }}
@@ -107,7 +107,9 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}(
     int8_t *arg,
     int8_t *res,
-    int N) {{
+    int32_t N_arg) {{
+  int N = N_arg;
+
   for ( int i = 0; i < N; ++ i )
     if ( arg[i] > res[i] )
       res[i] = arg[i];
@@ -120,7 +122,8 @@ extern "C"
 __attribute__((always_inline)) static inline int32_t max8_{uniq_id}(
     int8_t *arg,
     int8_t *res,
-    int N) {{
+    int32_t N_arg) {{
+  int N = N_arg;
   int32_t *parg32, *pres32;
   int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3;
   int32_t retcode = 0;
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py 
b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
index d2a8f1ef69..af3b23e01d 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py
@@ -390,8 +390,13 @@ def tensordot_int16_impl(
         #define {function_name.upper()}_EXISTS
         #include <arm_acle.h>
         __attribute__((always_inline)) static inline int32_t {function_name}(
-            int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, 
int32_t *scale
+            int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+            int32_t *bias, int32_t *scale
         ) {{
+          int32_t *output = (int32_t *)output_arg;
+          int32_t *tensor = (int32_t *)tensor_arg;
+          int32_t *kernel = (int32_t *)kernel_arg;
+
           {_init_biased_accumulators(num_outputs)}
 
           {insert_lines(load_tensor_lines)}
diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc 
b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
index 186fa30f20..6febfe3486 100644
--- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc
@@ -46,13 +46,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost {
     CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, 
target_str, devices);
   }
 
-  /*!
-   * \brief Emit code that offloads a subgraph to the Cortex-M
-   *
-   * \return string of code that offloads a subgraph to the Cortex-M
-   */
-  void AddFunction(const PrimFunc& prim_func) { 
CodeGenC::AddFunction(prim_func); }
-
  private:
   /*! \brief Enable storing the last error */
   bool debug_last_error;
@@ -575,11 +568,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) 
{
   bool emit_fwd_func_decl = false;
   bool debug_last_error = GetCompilerAttrs()->debug_last_error;
   CodeGenCMSISNN codegen;
-  Array<String> function_names;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), 
debug_last_error);
-  std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
-  for (auto kv : mod->functions) {
-    funcs.push_back(kv);
+
+  std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs;
+  for (auto [gvar, base_func] : mod->functions) {
+    funcs.push_back({gvar, Downcast<PrimFunc>(base_func)});
   }
 
   std::sort(funcs.begin(), funcs.end(),
@@ -594,13 +587,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) 
{
               return name_hint_a < name_hint_b;
             });
 
-  for (auto kv : funcs) {
-    auto prim_func = Downcast<PrimFunc>(kv.second);
-    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    function_names.push_back(global_symbol.value());
-    codegen.AddFunction(prim_func);
+  for (auto [gvar, prim_func] : funcs) {
+    codegen.AddFunction(gvar, prim_func);
   }
   std::string code = codegen.Finish();
+
+  Array<String> function_names;
+  for (auto [gvar, prim_func] : funcs) {
+    function_names.push_back(codegen.GetFunctionName(gvar));
+  }
+
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
 
diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc 
b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
index 0db8d06c31..6f09e0a0c3 100644
--- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc
@@ -49,16 +49,30 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool emit_asserts = false;
   bool emit_fwd_func_decl = false;
   CodeGenExampleTargetHook codegen;
-  Array<String> function_names;
+
   std::unordered_set<std::string> devices;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), 
devices);
-  for (auto kv : mod->functions) {
-    auto prim_func = Downcast<PrimFunc>(kv.second);
-    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    function_names.push_back(global_symbol.value());
-    codegen.AddFunction(prim_func);
+
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    codegen.DeclareFunction(gvar, prim_func);
+  }
+  for (auto [gvar, prim_func] : functions) {
+    codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
   }
+
   std::string code = codegen.Finish();
+
+  Array<String> function_names;
+  for (auto [gvar, prim_func] : functions) {
+    function_names.push_back(codegen.GetFunctionName(gvar));
+  }
+
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
 
diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc 
b/src/relay/backend/contrib/uma/tir_to_runtime.cc
index 3b58fda54b..487e247f5d 100644
--- a/src/relay/backend/contrib/uma/tir_to_runtime.cc
+++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc
@@ -49,13 +49,6 @@ class UMACodegen : public codegen::CodeGenCHost {
     CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, 
target_str_, devices);
   }
 
-  /*!
-   * \brief Emit code that offloads a subgraph to the UMA target
-   *
-   * \return string of code that offloads a subgraph to the UMA target
-   */
-  void AddFunction(const PrimFunc& prim_func) { 
CodeGenC::AddFunction(prim_func); }
-
  private:
   String target_str_;
 };
@@ -63,17 +56,30 @@ class UMACodegen : public codegen::CodeGenCHost {
 runtime::Module TIRToRuntime(IRModule mod, Target target) {
   bool output_ssa = false;
   bool emit_asserts = false;
-  bool emit_fwd_func_decl = false;
+  bool emit_fwd_func_decl = true;
   UMACodegen codegen(target->kind->name);
-  Array<String> function_names;
   codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl);
-  for (auto kv : mod->functions) {
-    auto prim_func = Downcast<PrimFunc>(kv.second);
-    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    function_names.push_back(global_symbol.value());
-    codegen.AddFunction(prim_func);
+
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    codegen.DeclareFunction(gvar, prim_func);
+  }
+  for (auto [gvar, prim_func] : functions) {
+    codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl);
   }
+
   std::string code = codegen.Finish();
+
+  Array<String> function_names;
+  for (auto [gvar, prim_func] : functions) {
+    function_names.push_back(codegen.GetFunctionName(gvar));
+  }
+
   return codegen::CSourceModuleCreate(code, "c", function_names);
 }
 
diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc
index 1c0b5094ef..e0f53e3509 100644
--- a/src/target/opt/build_cuda_on.cc
+++ b/src/target/opt/build_cuda_on.cc
@@ -131,13 +131,21 @@ runtime::Module BuildCUDA(IRModule mod, Target target) {
   CodeGenCUDA cg;
   cg.Init(output_ssa);
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only 
take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only 
take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenCUDA: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch";
-    cg.AddFunction(f);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    cg.DeclareFunction(gvar, prim_func);
+  }
+  for (auto [gvar, prim_func] : functions) {
+    cg.AddFunction(gvar, prim_func);
   }
 
   std::string code = cg.Finish();
diff --git a/src/target/source/codegen_aocl.cc 
b/src/target/source/codegen_aocl.cc
index 700d85b4cc..dc3ba08751 100644
--- a/src/target/source/codegen_aocl.cc
+++ b/src/target/source/codegen_aocl.cc
@@ -40,13 +40,22 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool 
emulation) {
   CodeGenOpenCL cg;
   cg.Init(output_ssa);
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only 
take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only 
take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodegenOpenCL: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch";
-    cg.AddFunction(f);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    cg.DeclareFunction(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    cg.AddFunction(gvar, prim_func);
   }
 
   std::string code = cg.Finish();
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index a7cc320562..187bdc74fe 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -42,6 +42,7 @@ void CodeGenC::InitFuncState(const PrimFunc& f) {
   alloc_storage_scope_.clear();
   handle_data_type_.clear();
   CodeGenSourceBase::ClearFuncState();
+  ReserveKeywordsAsUnique();
 }
 
 void CodeGenC::ReserveKeywordsAsUnique() {
@@ -75,51 +76,92 @@ void CodeGenC::ReserveKeywordsAsUnique() {
   name_supply_->ReserveName("return");
 }
 
-void CodeGenC::AddFunction(const PrimFunc& f) {
-  // clear previous generated state.
-  this->InitFuncState(f);
-  // reserve keywords
-  ReserveKeywordsAsUnique();
+void CodeGenC::PrintFunctionSignature(const String& function_name, const 
PrimFunc& func,
+                                      std::ostream& os) {
+  PrintFuncPrefix(os);
+  PrintType(func->ret_type, os);
+  PrintExtraAttrs(func, os);
+  os << " " << function_name << "(";
+  for (size_t i = 0; i < func->params.size(); ++i) {
+    tir::Var v = func->params[i];
 
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
-      << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
-  bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias);
-
-  this->PrintFuncPrefix(stream);
-  PrintType(f->ret_type, stream);
-  this->PrintExtraAttrs(f);
-  this->stream << " " << static_cast<std::string>(global_symbol.value()) << 
"(";
-
-  for (size_t i = 0; i < f->params.size(); ++i) {
-    tir::Var v = f->params[i];
-    std::string vid = AllocVarID(v.get());
-    if (i != 0) stream << ", ";
-    if (v.dtype().is_handle()) {
-      auto it = alloc_storage_scope_.find(v.get());
-      if (it != alloc_storage_scope_.end()) {
-        PrintStorageScope(it->second, stream);
-      }
+    if (i > 0) {
+      os << ", ";
+    }
 
-      PrintType(GetType(v), stream);
-      // Register handle data type
-      // TODO(tvm-team): consider simply keep type info in the
-      // type annotation(via a normalizing rewriting).
-      if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) {
-        if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
-          RegisterHandleType(v.get(), prim->dtype);
-        }
-      }
+    if (auto it = alloc_storage_scope_.find(v.get()); it != 
alloc_storage_scope_.end()) {
+      PrintStorageScope(it->second, os);
+    }
 
-      if (no_alias) {
-        PrintRestrict(v, stream);
+    PrintType(GetType(v), os);
+
+    bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias);
+    bool is_handle = v.dtype().is_handle();
+    if (no_alias && is_handle) {
+      PrintRestrict(v, os);
+    }
+
+    os << " " << AllocVarID(v.get());
+  }
+  os << ")";
+
+  // Register handle data type
+  // TODO(tvm-team): consider simply keep type info in the
+  // type annotation(via a normalizing rewriting).
+  for (const auto& param : func->params) {
+    if (auto* ptr = param->type_annotation.as<PointerTypeNode>()) {
+      if (auto* prim = ptr->element_type.as<PrimTypeNode>()) {
+        RegisterHandleType(param.get(), prim->dtype);
       }
-    } else {
-      PrintType(GetType(v), stream);
     }
-    stream << ' ' << vid;
   }
-  stream << ") {\n";
+}
+
+void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) {
+  if (internal_functions_.count(gvar)) {
+    return;
+  }
+
+  auto function_name = [&]() -> String {
+    if (auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
+      auto name = global_symbol.value();
+      ICHECK(!func_name_supply_->ContainsName(name))
+          << "Function " << gvar << " must use global symbol " << name
+          << ", but this name has already been used.";
+      func_name_supply_->ReserveName(name);
+      return name;
+    } else {
+      func_name_supply_->ReserveName(gvar->name_hint);
+      return gvar->name_hint;
+    }
+  }();
+
+  internal_functions_.insert({gvar, function_name});
+
+  InitFuncState(func);
+  PrintFunctionSignature(function_name, func, fwd_decl_stream);
+  fwd_decl_stream << ";\n";
+}
+
+String CodeGenC::GetFunctionName(const GlobalVar& gvar) {
+  auto it = internal_functions_.find(gvar);
+  ICHECK(it != internal_functions_.end())
+      << "Attempted to find name of " << gvar
+      << ", but no function with this GlobalVar has been declared";
+  return it->second;
+}
+
+void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) {
+  // If the function has already been forward-declared, this is a
+  // no-op.
+  DeclareFunction(gvar, f);
+  auto function_name = GetFunctionName(gvar);
+
+  // clear previous generated state.
+  InitFuncState(f);
+
+  PrintFunctionSignature(function_name, f, stream);
+  stream << " {\n";
   this->PreFunctionBody(f);
   int func_scope = this->BeginScope();
   this->PrintStmt(f->body);
@@ -130,9 +172,15 @@ void CodeGenC::AddFunction(const PrimFunc& f) {
 
 void CodeGenC::PrintFuncPrefix(std::ostream& os) {}
 
-void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {}
+void CodeGenC::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {}
 
-std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); }
+std::string CodeGenC::Finish() {
+  std::ostringstream code;
+  code << decl_stream.str();
+  code << fwd_decl_stream.str();
+  code << stream.str();
+  return code.str();
+}
 
 void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) {  // NOLINT(*)
   if (print_ssa_form_) {
@@ -542,12 +590,17 @@ void CodeGenC::VisitExpr_(const CallNode* op, 
std::ostream& os) {  // NOLINT(*)
       ICHECK_GE(op->args.size(), 1U);
       auto func = Downcast<StringImm>(op->args[0]);
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, 
op->args, true, os);
-      Array<Type> arg_types;
-      for (size_t i = 1; i < op->args.size(); i++) {
-        arg_types.push_back(GetType(op->args[i]));
+
+      // If the call_extern refers to a function within the IRModule, then
+      // the forward declaration is already provided from DeclareFunction.
+      if (!func_name_supply_->ContainsName(func->value)) {
+        Array<Type> arg_types;
+        for (size_t i = 1; i < op->args.size(); i++) {
+          arg_types.push_back(GetType(op->args[i]));
+        }
+        Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
+        this->GenerateForwardFunctionDeclarations(func->value, arg_types, 
ret_type);
       }
-      Type ret_type = GetTypeFromRuntimeDataType(op->dtype);
-      this->GenerateForwardFunctionDeclarations(func->value, arg_types, 
ret_type);
     } else if (op_attr_global_symbol_.count(call_op)) {
       // call extern if the op itself have a global symbol.
       this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), 
op_attr_global_symbol_[call_op],
@@ -615,9 +668,13 @@ void CodeGenC::VisitExpr_(const CallNode* op, 
std::ostream& os) {  // NOLINT(*)
     } else {
       LOG(FATAL) << "Unresolved call " << op->op;
     }
+  } else if (auto opt = op->op.as<GlobalVar>()) {
+    auto gvar = opt.value();
+    auto callee_name = GetFunctionName(gvar);
+    PrintCallExtern(GetType(GetRef<PrimExpr>(op)), callee_name, op->args, 
false, os);
   } else {
-    ICHECK(op->op.as<GlobalVarNode>());
-    LOG(FATAL) << "Do not yet support cross function call";
+    LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a 
recognized built-in, "
+               << "nor a GlobalVar reference to another function in the 
IRModule";
   }
 }
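
Two details in the hunk above are easy to miss: a call whose op is a
GlobalVar reuses the name reserved by DeclareFunction, and call_extern only
generates a forward declaration when the callee's name is not already known
to func_name_supply_.  A self-contained sketch of that second check, using
std::unordered_set as a stand-in for TVM's NameSupply:

    #include <iostream>
    #include <string>
    #include <unordered_set>
    #include <vector>

    int main() {
      // Names reserved by DeclareFunction() for functions in the IRModule.
      std::unordered_set<std::string> func_name_supply = {"subroutine"};

      // Callees reached through builtin::call_extern.
      const std::vector<std::string> callees = {"subroutine", "arm_dsp_kernel"};
      for (const auto& callee : callees) {
        if (func_name_supply.count(callee)) {
          // Declared from the IRModule; a second declaration could conflict.
          std::cout << "skip " << callee << "\n";
        } else {
          std::cout << "emit forward declaration for " << callee << "\n";
        }
      }
      return 0;
    }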
 
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h
index 93f9ea519c..2921a56ef3 100644
--- a/src/target/source/codegen_c.h
+++ b/src/target/source/codegen_c.h
@@ -65,12 +65,33 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, 
std::ostream&)>,
    * \param output_ssa Whether output SSA.
    */
   void Init(bool output_ssa);
+
   /*!
-   * \brief Add the function to the generated module.
-   * \param f The function to be compiled.
+   * \brief Add the function declaration to the generated module,
+   * without defining it.
+   *
+   * \param gvar The GlobalVar representing the function.
+   * \param func The function to be compiled.
    */
-  void AddFunction(const PrimFunc& f);
+  virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func);
+
+  /*!
+   * \brief Add the function to the generated module, including its
+   * declaration and definition.
+   *
+   * \param gvar The GlobalVar representing the function.
+   * \param func The function to be compiled.
+   */
+  virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func);
+
+  /*!
+   * \brief Get the name of a declared function
+   * \param gvar The GlobalVar of the function
+   * \returns The string name of the function
+   */
+  String GetFunctionName(const GlobalVar& gvar);
+
   /*!
    * \brief Finalize the compilation and return the code.
    * \return The code.
@@ -96,7 +117,23 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, 
std::ostream&)>,
     PrintExpr(n, os);
     return os.str();
   }
+
   // The following parts are overloadable print operations.
+
+  /*! \brief Print the function signature, including the argument list
+   *
+   * The default implementation delegates out to PrintFuncPrefix and
+   * PrintExtraAttrs.
+   *
+   * \param function_name The name of the function
+   *
+   * \param func The function whose signature should be printed
+   *
+   * \param os The output stream
+   */
+  virtual void PrintFunctionSignature(const String& function_name, const 
PrimFunc& func,
+                                      std::ostream& os);
+
   /*!
    * \brief Print the function header before the argument list
    * \param os The output stream
@@ -109,7 +146,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, 
std::ostream&)>,
    *
    *  Example: __launch_bounds__(256) for CUDA functions
    */
-  virtual void PrintExtraAttrs(const PrimFunc& f);
+  virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os);  // 
NOLINT(*)
   /*!
    * \brief Insert statement before function body.
    * \param f The function to be compiled.
@@ -284,10 +321,24 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, 
std::ostream&)>,
  private:
   /*! \brief set of volatile buf access */
   std::unordered_set<const VarNode*> volatile_buf_;
+
   // deep comparison of PrimExpr
   ExprDeepEqual deep_equal_;
+
   // binding of let variables. Enables duplicate var defs that map to same 
value
   std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> 
let_binding_;
+
+  /*! \brief Map from each GlobalVar to its symbol.
+   *
+   * For externally-exposed functions, this is given by the
+   * tvm::attr::kGlobalSymbol attribute of the PrimFunc.  For internal
+   * functions, this is the name of the function's GlobalVar, possibly
+   * altered to prevent duplicate names.
+   */
+  std::unordered_map<GlobalVar, String, ObjectPtrHash, ObjectPtrEqual> 
internal_functions_;
+
+  /*! \brief Name supply to generate unique function names */
+  NameSupply func_name_supply_{""};
 };
 
 }  // namespace codegen
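
Taken together, the declarations above form a two-phase interface.  A hedged
sketch of the call sequence a backend is expected to follow (mirroring
BuildCHost below; `cg` is assumed to be an Init()'d CodeGenC subclass and
`functions` a populated Map<GlobalVar, PrimFunc>):

    // Phase 1: reserve each function's name and emit its forward
    // declaration, so functions may reference each other in any order.
    for (auto [gvar, prim_func] : functions) {
      cg.DeclareFunction(gvar, prim_func);
    }

    // Phase 2: emit the definitions.  AddFunction() starts with a
    // DeclareFunction() call, which is a no-op if already declared.
    for (auto [gvar, prim_func] : functions) {
      cg.AddFunction(gvar, prim_func);
    }

    // Finish() concatenates decl_stream, fwd_decl_stream, and stream.
    std::string code = cg.Finish();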
diff --git a/src/target/source/codegen_c_host.cc 
b/src/target/source/codegen_c_host.cc
index 3255e11c5d..caef43e8af 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -75,19 +75,24 @@ void CodeGenCHost::InitGlobalContext() {
 
 void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << 
module_name_ << " = NULL;\n"; }
 
-void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) {
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
-      << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute";
-  function_names_.push_back(global_symbol.value());
+void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func,
+                               bool emit_fwd_func_decl) {
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  if (global_symbol) {
+    function_names_.push_back(global_symbol.value());
+  }
 
   emit_fwd_func_decl_ = emit_fwd_func_decl;
-  CodeGenC::AddFunction(f);
-  if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+  CodeGenC::AddFunction(gvar, func);
+  if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
+    ICHECK(global_symbol.defined())
+        << "CodeGenCHost: The entry func must have the global_symbol 
attribute, "
+        << "but function " << gvar << " only has attributes " << func->attrs;
+
     function_names_.push_back(runtime::symbol::tvm_module_main);
     stream << "// CodegenC: NOTE: Auto-generated entry function\n";
     PrintFuncPrefix(stream);
-    PrintType(f->ret_type, stream);
+    PrintType(func->ret_type, stream);
     stream << " " << tvm::runtime::symbol::tvm_module_main
            << "(void* args, int* arg_type_ids, int num_args, void* 
out_ret_value, "
            << "int* out_ret_tcode, void* resource_handle) {\n";
@@ -128,15 +133,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) {  // 
NOLINT(*)
      << "TVM_DLL ";
 }
 
-std::string CodeGenCHost::Finish() {  // NOLINT(*)
-  std::string ret = decl_stream.str();
-  if (emit_fwd_func_decl_) {
-    ret += fwd_decl_stream.str();
-  }
-  ret += stream.str();
-  return ret;
-}
-
 void CodeGenCHost::PrintType(DataType t, std::ostream& os) {  // NOLINT(*)
   int lanes = t.lanes();
   if (t.is_handle()) {
@@ -437,42 +433,38 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   CodeGenCHost cg;
   cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), 
devices);
   
cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16));
-  PrimFunc aot_executor_fn;
-
-  std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs;
-  for (auto kv : mod->functions) {
-    // Make sure that the executor function is the last one to be code 
generated so that all the
-    // symbols are available to __tvm_main__
-    auto fun_name = std::string(kv.first->name_hint);
-    bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", 
Bool(false)).value();
-
-    if (is_aot_executor_fn) {
-      aot_executor_fn = Downcast<PrimFunc>(kv.second);
-      continue;
-    }
-    funcs.push_back(kv);
+
+  auto is_aot_executor_fn = [](const PrimFunc& func) -> bool {
+    return func->GetAttr<Bool>("runner_function", Bool(false)).value();
+  };
+
+  std::vector<std::pair<GlobalVar, PrimFunc>> funcs;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only 
take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    funcs.push_back({gvar, prim_func});
   }
 
   // Sort functions
-  std::sort(funcs.begin(), funcs.end(),
-            [](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a,
-               std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) {
-              std::string name_hint_a = kv_a.first->name_hint;
-              std::string name_hint_b = kv_b.first->name_hint;
-              return name_hint_a < name_hint_b;
-            });
-
-  // Add all functions except __tvm_main__
-  for (auto& kv : funcs) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only 
take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    cg.AddFunction(f);
+  auto sort_key = [&is_aot_executor_fn](const auto& kv) {
+    return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint};
+  };
+  std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const 
auto& kv_b) {
+    return sort_key(kv_a) < sort_key(kv_b);
+  });
+
+  // Declare all functions first.  This ensures that all functions,
+  // including the __tvm_main__ used in AOT, have access to forward
+  // declarations of other functions in the IRModule.
+  for (const auto& [gvar, prim_func] : funcs) {
+    cg.DeclareFunction(gvar, prim_func);
   }
 
-  // Add __tvm_main__
-  if (aot_executor_fn.defined()) {
-    emit_fwd_func_decl = true;
-    cg.AddFunction(aot_executor_fn, emit_fwd_func_decl);
+  // Codegen all functions.  Passing emit_fwd_func_decl=true adds a
+  // forward declaration for any `builtin::call_extern`, based on the
+  // arguments provided to it.
+  for (const auto& [gvar, prim_func] : funcs) {
+    cg.AddFunction(gvar, prim_func, emit_fwd_func_decl);
   }
 
   // NOTE: it's possible that kRuntime attr is not attached when the mod was 
built with tvm.build().
@@ -484,7 +476,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) {
   } else {
     runtime = relay::Runtime::Create("cpp", {});
   }
-  if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) {
+
+  bool has_aot_executor_fn = std::any_of(
+      funcs.begin(), funcs.end(), [&](const auto& kv) { return 
is_aot_executor_fn(kv.second); });
+  if (has_aot_executor_fn && runtime->name == relay::kTvmRuntimeCpp) {
     cg.InitGlobalContext();
   }
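
The sort_key above folds the old "code-generate the executor function last"
special case into the ordering itself: functions sort by the pair
(is_aot_executor_fn, name_hint), keeping the deterministic alphabetical order
while moving runner functions to the end.  A self-contained illustration of
the tuple-key comparison:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <tuple>
    #include <utility>
    #include <vector>

    int main() {
      // (is_aot_executor_fn, name_hint) stand-ins.
      std::vector<std::pair<bool, std::string>> funcs = {
          {false, "subroutine_b"}, {true, "__tvm_main__"}, {false, "subroutine_a"}};

      std::sort(funcs.begin(), funcs.end(), [](const auto& a, const auto& b) {
        return std::tuple{a.first, a.second} < std::tuple{b.first, b.second};
      });

      // Prints subroutine_a, subroutine_b, then __tvm_main__: false sorts
      // before true, so the runner function is emitted last.
      for (const auto& [is_runner, name] : funcs) {
        std::cout << name << "\n";
      }
      return 0;
    }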
 
diff --git a/src/target/source/codegen_c_host.h 
b/src/target/source/codegen_c_host.h
index 694104afc0..aeba685f74 100644
--- a/src/target/source/codegen_c_host.h
+++ b/src/target/source/codegen_c_host.h
@@ -44,8 +44,7 @@ class CodeGenCHost : public CodeGenC {
             const std::unordered_set<std::string>& devices);
 
   void InitGlobalContext();
-  void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false);
-  std::string Finish() final;
+  void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool 
emit_fwd_func_decl = false);
   /*!
    * \brief Add functions from the (unordered) range to the current module in 
a deterministic
    * order. This helps with debugging.
diff --git a/src/target/source/codegen_cuda.cc 
b/src/target/source/codegen_cuda.cc
index a91f8b0164..7639ce6065 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -75,7 +75,7 @@ class ThreadIdxExtractor : public tir::StmtVisitor {
   PrimExpr threadIdx_z_ext = Integer(1);
 };
 
-void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
+void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {
   ThreadIdxExtractor extractor;
   extractor(f->body);
   arith::Analyzer analyzer;
@@ -86,7 +86,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) {
       // unable to extract the number of threads per block, hence directly 
return
       return;
     }
-    stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
+    os << " __launch_bounds__(" << threadIdx_ext_int->value << ")";
   }
 }
 
diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h
index 3ec0c3bc2d..bc7b34b500 100644
--- a/src/target/source/codegen_cuda.h
+++ b/src/target/source/codegen_cuda.h
@@ -47,7 +47,7 @@ class CodeGenCUDA final : public CodeGenC {
   }
   // override behavior
   void PrintFuncPrefix(std::ostream& os) final;
-  void PrintExtraAttrs(const PrimFunc& f) final;
+  void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final;  // 
NOLINT(*)
   void VisitStmt_(const ForNode* op) final;
   void PrintStorageSync(const CallNode* op) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  
// NOLINT(*)
diff --git a/src/target/source/codegen_metal.cc 
b/src/target/source/codegen_metal.cc
index b8c30691e2..ebb7566489 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -36,6 +36,8 @@ namespace codegen {
 
 void CodeGenMetal::InitFuncState(const PrimFunc& f) {
   CodeGenC::InitFuncState(f);
+  // skip the first underscore, so SSA variable starts from _1
+  name_supply_->FreshName("v_");
   // analyze the data;
   for (Var arg : f->params) {
     if (arg.dtype().is_handle()) {
@@ -52,37 +54,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) 
{
               << "};\n\n";
 }
 
-void CodeGenMetal::AddFunction(const PrimFunc& f) {
-  // clear previous generated state.
-  this->InitFuncState(f);
-  // skip the first underscore, so SSA variable starts from _1
-  name_supply_->FreshName("v_");
-
+void CodeGenMetal::PrintFunctionSignature(const String& function_name, const 
PrimFunc& func,
+                                          std::ostream& os) {
   // add to alloc buffer type.
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   ICHECK(global_symbol.defined())
       << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
 
   // Function header.
-  this->stream << "kernel void " << 
static_cast<std::string>(global_symbol.value()) << "(";
+  os << "kernel void " << static_cast<std::string>(global_symbol.value()) << 
"(";
 
   // Buffer arguments
   size_t num_buffer = 0;
   size_t limit = 
target_->GetAttr<Integer>("max_function_args").value().IntValue();
-  if (f->params.size() > limit) {
+  if (func->params.size() > limit) {
     LOG(WARNING) << "Probably you won't be able to execute your kernel due to 
high number of "
                     "buffers in the kernel";
   }
-  for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) {
-    Var v = f->params[i];
+  for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) {
+    Var v = func->params[i];
     if (!v.dtype().is_handle()) break;
-    stream << "  ";
+    os << "  ";
     std::string vid = AllocVarID(v.get());
     auto it = alloc_storage_scope_.find(v.get());
     if (it != alloc_storage_scope_.end()) {
-      PrintStorageScope(it->second, stream);
+      PrintStorageScope(it->second, os);
     }
-    PrintType(GetType(v), stream);
+    PrintType(GetType(v), os);
     // Register handle data type
     // TODO(tvm-team): consider simply keep type info in the
     // type annotation(via a normalizing rewriting).
@@ -91,19 +89,18 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
         RegisterHandleType(v.get(), prim->dtype);
       }
     }
-    stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
+    os << ' ' << vid << " [[ buffer(" << i << ") ]],\n";
   }
   // Setup normal arguments.
-  size_t nargs = f->params.size() - num_buffer;
+  size_t nargs = func->params.size() - num_buffer;
   std::string varg = name_supply_->FreshName("arg");
   if (nargs != 0) {
     std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) 
+ "_args_t";
-    stream << "  constant " << arg_buf_type << "& " << varg << " [[ buffer(" 
<< num_buffer
-           << ") ]],\n";
+    os << "  constant " << arg_buf_type << "& " << varg << " [[ buffer(" << 
num_buffer << ") ]],\n";
     // declare the struct
     decl_stream << "struct " << arg_buf_type << " {\n";
-    for (size_t i = num_buffer; i < f->params.size(); ++i) {
-      Var v = f->params[i];
+    for (size_t i = num_buffer; i < func->params.size(); ++i) {
+      Var v = func->params[i];
       ICHECK(!v.dtype().is_handle());
       std::string vid = AllocVarID(v.get());
       std::ostringstream vref;
@@ -131,7 +128,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
   ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
   ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
   int work_dim = 0;
-  auto launch_params = 
f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
+  auto launch_params = 
func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value();
   for (const auto& tag : launch_params) {
     if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) {
       runtime::ThreadScope scope = runtime::ThreadScope::Create(tag);
@@ -141,22 +138,16 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) {
 
   if (work_dim != 0) {
     // use ushort by default for now
-    stream << "  ";
-    PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
-    stream << " blockIdx [[threadgroup_position_in_grid]],\n";
-    stream << "  ";
-    PrintType(DataType::UInt(thread_index_bits_, work_dim), stream);
-    stream << " threadIdx [[thread_position_in_threadgroup]]\n";
+    os << "  ";
+    PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
+    os << " blockIdx [[threadgroup_position_in_grid]],\n";
+    os << "  ";
+    PrintType(DataType::UInt(thread_index_bits_, work_dim), os);
+    os << " threadIdx [[thread_position_in_threadgroup]]\n";
   }
   thread_work_dim_ = work_dim;
 
-  // the function scope.
-  stream << ") {\n";
-  int func_scope = this->BeginScope();
-  this->PrintStmt(f->body);
-  this->EndScope(func_scope);
-  this->PrintIndent();
-  this->stream << "}\n\n";
+  os << ")";
 }
 
 void CodeGenMetal::BindThreadIndex(const IterVar& iv) {
@@ -342,27 +333,33 @@ runtime::Module BuildMetal(IRModule mod, Target target) {
   const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile");
   std::string fmt = fmetal_compile ? "metallib" : "metal";
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only 
take PrimFunc";
-    auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    ICHECK(global_symbol.defined());
-    std::string func_name = global_symbol.value();
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only 
take PrimFunc";
+    auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+        << "CodeGenMetal: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch";
+
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    functions.Set(gvar, prim_func);
+  }
 
-    source_maker << "// Function: " << func_name << "\n";
+  for (auto [gvar, prim_func] : functions) {
+    source_maker << "// Function: " << gvar->name_hint << "\n";
     CodeGenMetal cg(target);
     cg.Init(output_ssa);
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenMetal: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch";
 
-    cg.AddFunction(f);
+    for (auto [other_gvar, other_prim_func] : functions) {
+      cg.DeclareFunction(other_gvar, other_prim_func);
+    }
+    cg.AddFunction(gvar, prim_func);
+
     std::string fsource = cg.Finish();
     source_maker << fsource << "\n";
     if (fmetal_compile) {
       fsource = (*fmetal_compile)(fsource, target).operator std::string();
     }
-    smap[func_name] = fsource;
+    smap[cg.GetFunctionName(gvar)] = fsource;
   }
 
   return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, 
source_maker.str());
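
For reference, the signature pieces printed by PrintFunctionSignature above
assemble into a Metal kernel header of roughly this shape (buffer names,
element types, and indices are illustrative, not taken from the commit):

    kernel void my_kernel(
      device float* A [[ buffer(0) ]],
      device float* B [[ buffer(1) ]],
      constant my_kernel_args_t& arg [[ buffer(2) ]],
      ushort3 blockIdx [[threadgroup_position_in_grid]],
      ushort3 threadIdx [[thread_position_in_threadgroup]]
    )

    // DeclareFunction() terminates this with ";" in fwd_decl_stream, while
    // AddFunction() continues with " {" and the function body.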
diff --git a/src/target/source/codegen_metal.h 
b/src/target/source/codegen_metal.h
index 36be10d163..26c991e60d 100644
--- a/src/target/source/codegen_metal.h
+++ b/src/target/source/codegen_metal.h
@@ -38,7 +38,8 @@ class CodeGenMetal final : public CodeGenC {
   explicit CodeGenMetal(Target target);
   // override print thread tag.
   void PrintArgUnionDecl();
-  void AddFunction(const PrimFunc& f);  // NOLINT(*)
+  void PrintFunctionSignature(const String& function_name, const PrimFunc& 
func,
+                              std::ostream& os) override;
   void InitFuncState(const PrimFunc& f) final;
   void PrintStorageScope(const std::string& scope, std::ostream& os) final;  
// NOLINT(*)
   void PrintStorageSync(const CallNode* op) final;                           
// NOLINT(*)
diff --git a/src/target/source/codegen_opencl.cc 
b/src/target/source/codegen_opencl.cc
index c15d2253d7..da6a4de619 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -595,18 +595,26 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) {
   using tvm::runtime::Registry;
   bool output_ssa = false;
 
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only 
take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
+    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
+        << "CodeGenOpenCL: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch";
+    functions.Set(gvar, prim_func);
+  }
+
   std::stringstream code;
   const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc");
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only 
take PrimFunc";
-    code << "// Function: " << kv.first->name_hint << std::endl;
+  for (auto [gvar, prim_func] : functions) {
+    code << "// Function: " << gvar->name_hint << std::endl;
     CodeGenOpenCL cg;
     cg.Init(output_ssa);
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
-    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
-        << "CodeGenOpenCL: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch";
-    cg.AddFunction(f);
+    for (auto [other_gvar, other_prim_func] : functions) {
+      cg.DeclareFunction(other_gvar, other_prim_func);
+    }
+    cg.AddFunction(gvar, prim_func);
     std::string fsource = cg.Finish();
     if (fpostproc) {
       fsource = (*fpostproc)(fsource, target).operator std::string();
diff --git a/src/target/source/codegen_vhls.cc 
b/src/target/source/codegen_vhls.cc
index 83046de107..aa7a32320c 100644
--- a/src/target/source/codegen_vhls.cc
+++ b/src/target/source/codegen_vhls.cc
@@ -145,13 +145,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) 
{
   // Generate source code for get_source().
   cg.Init(output_ssa);
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only 
take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only 
take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenVLHS: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch";
-    cg.AddFunction(f);
+    functions.Set(gvar, prim_func);
+  }
+
+  for (auto [gvar, prim_func] : functions) {
+    cg.DeclareFunction(gvar, prim_func);
+  }
+  for (auto [gvar, prim_func] : functions) {
+    cg.AddFunction(gvar, prim_func);
   }
 
   std::string whole_code = cg.Finish();
@@ -159,21 +167,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) 
{
   // Generate source code for compilation.
   Array<Array<runtime::String>> kernel_info;
 
-  for (auto kv : mod->functions) {
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only 
take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
+  for (auto [gvar, prim_func] : functions) {
     CodeGenVivadoHLS cg;
     cg.Init(output_ssa);
-    cg.AddFunction(f);
+
+    for (auto [other_gvar, other_prim_func] : functions) {
+      cg.DeclareFunction(other_gvar, other_prim_func);
+    }
+    cg.AddFunction(gvar, prim_func);
     std::string code = cg.Finish();
     if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) {
       code = (*f)(code, target).operator std::string();
     }
 
-    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    ICHECK(global_symbol.defined())
-        << "CodeGenC: Expect PrimFunc to have the global_symbol attribute";
-    kernel_info.push_back({global_symbol.value(), code});
+    auto function_name = cg.GetFunctionName(gvar);
+    kernel_info.push_back({function_name, code});
   }
 
   std::string xclbin;
diff --git a/src/target/source/codegen_webgpu.cc 
b/src/target/source/codegen_webgpu.cc
index 4d1d834c7f..6a6712a4ce 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -45,6 +45,12 @@ std::string CodeGenWebGPU::Finish() {
 
 void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
   CodeGenC::InitFuncState(f);
+  // skip the first underscore, so SSA variable starts from _1
+  name_supply_->FreshName("v_");
+  // Setup the thread group info.
+  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
+  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
+
   // analyze the data;
   for (Var arg : f->params) {
     if (arg.dtype().is_handle()) {
@@ -56,28 +62,12 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
 
 CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {}
 
-void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
-  // clear previous generated state.
-  this->InitFuncState(f);
-  // skip the first underscore, so SSA variable starts from
-  name_supply_->FreshName("v_");
-  // Setup the thread group info.
-  ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx");
-  ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx");
-
-  // add to alloc buffer type.
-  auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  ICHECK(global_symbol.defined())
-      << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute";
-
-  decl_stream << "//----------------------------------------\n"
-              << "// function: " << global_symbol.value() << "\n"
-              << "//----------------------------------------\n";
-
+void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const 
PrimFunc& func,
+                                           std::ostream& os) {
   std::vector<Var> pod_args;
   int num_buffer = 0;
   // setup buffer arguments
-  for (Var arg : f->params) {
+  for (Var arg : func->params) {
     DataType t = arg.dtype();
     if (t.is_handle()) {
       auto* ptr = arg->type_annotation.as<PointerTypeNode>();
@@ -111,16 +101,18 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) {
   }
   // add to alloc buffer type.
   // Function header.
-  this->stream << "fn main(\n"
-               << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
-               << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
-               << ") {\n";
-  // the function scope.
-  int func_scope = this->BeginScope();
-  this->PrintStmt(f->body);
-  this->EndScope(func_scope);
-  this->PrintIndent();
-  this->stream << "}\n\n";
+  os << "fn main(\n"
+     << "  @builtin(workgroup_id) blockIdx : vec3<u32>,\n"
+     << "  @builtin(local_invocation_id) threadIdx : vec3<u32>\n"
+     << ")";
+}
+
+void CodeGenWebGPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) {
+  CodeGenC::AddFunction(gvar, func);
+  decl_stream << "//----------------------------------------\n"
+              << "// function: " << GetFunctionName(gvar) << "\n"
+              << "//----------------------------------------\n";
+
+  // annotate workgroup
+  this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", "
                        << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n";
@@ -524,22 +516,31 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) {
   mod = tir::transform::PointerValueTypeRewrite()(std::move(mod));
   bool output_ssa = false;
 
-  std::unordered_map<std::string, std::string> smap;
-  for (auto kv : mod->functions) {
-    CodeGenWebGPU cg(target);
-    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc";
-    auto f = Downcast<PrimFunc>(kv.second);
-    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
+  Map<GlobalVar, PrimFunc> functions;
+  for (auto [gvar, base_func] : mod->functions) {
+    ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc";
+    auto prim_func = Downcast<PrimFunc>(base_func);
+    auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv);
     ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
         << "CodeGenWebGPU: expect calling_conv equals 
CallingConv::kDeviceKernelLaunch";
-    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+    auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
     ICHECK(global_symbol.defined())
         << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol 
attribute";
-    std::string f_name = global_symbol.value();
+    functions.Set(gvar, prim_func);
+  }
+
+  std::unordered_map<std::string, std::string> smap;
+  for (auto [gvar, prim_func] : functions) {
+    CodeGenWebGPU cg(target);
     cg.Init(output_ssa);
-    cg.AddFunction(f);
+
+    for (auto [other_gvar, other_prim_func] : functions) {
+      cg.DeclareFunction(other_gvar, other_prim_func);
+    }
+    cg.AddFunction(gvar, prim_func);
+
     std::string code = cg.Finish();
-    smap[f_name] = code;
+    smap[cg.GetFunctionName(gvar)] = code;
   }
diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h
index 57f226ba8a..6ae942a3ad 100644
--- a/src/target/source/codegen_webgpu.h
+++ b/src/target/source/codegen_webgpu.h
@@ -48,7 +48,9 @@ class CodeGenWebGPU final : public CodeGenC {
   explicit CodeGenWebGPU(Target target);
   // overrides
   std::string Finish() final;
-  void AddFunction(const PrimFunc& f);  // NOLINT(*)
+  void PrintFunctionSignature(const String& function_name, const PrimFunc& func,
+                              std::ostream& os) final;
+  void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final;
   void InitFuncState(const PrimFunc& f) final;
   void PrintStorageSync(const CallNode* op) final;     // NOLINT(*)
   void PrintType(DataType t, std::ostream& os) final;  // NOLINT(*)
diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc
index a6f4b5bb3e..90640a6db6 100644
--- a/src/target/source/source_module.cc
+++ b/src/target/source/source_module.cc
@@ -613,12 +613,14 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
       }
 
       for (const tir::Var& pool_var : metadata_->pools) {
+        call_args_ss << "((uint8_t*)";
        String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name;
         if (IsInternalWorkspaceBuffer(pool_var)) {
-          call_args_ss << "&" << pool_name << ",";
+          call_args_ss << "&" << pool_name;
         } else {
-          call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ",";
+          call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name);
         }
+        call_args_ss << "),";
       }
       for (const String& device : metadata_->devices) {
         call_args_ss << "devices->" << device << ",";
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 39214c4546..fd14f48921 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -70,6 +70,32 @@ Type GetType(const PrimExpr& expr) {
       return ptr->type_annotation;
     }
   }
+
+  if (auto* access = expr.as<tir::CallNode>()) {
+    if (access->op.same_as(builtin::tvm_access_ptr())) {
+      ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments";
+      auto type_annotation = Downcast<Call>(access->args[0]);
+      static auto builtin_op = Op::Get("tir.type_annotation");
+      ICHECK(type_annotation->op.same_as(builtin_op))
+          << "Expected the first argument of builtin tvm_access_ptr() "
+          << "to be a type annotation, but found " << type_annotation->op;
+      return PointerType(PrimType(type_annotation->dtype));
+    }
+  }
+
+  if (auto* address_of = expr.as<tir::CallNode>()) {
+    if (address_of->op.same_as(builtin::address_of())) {
+      ICHECK_EQ(address_of->args.size(), 1)
+          << "Builtin address_of() expects a single argument, but received 
arguments "
+          << address_of->args;
+      auto* address = address_of->args[0].as<BufferLoadNode>();
+      ICHECK(address)
+          << "Builtin address_of() expects the argument to be a BufferLoad, 
but received argument "
+          << address_of->args[0];
+
+      return PointerType(PrimType(address->dtype));
+    }
+  }
   // Default: return the type indicated by the dtype.
   runtime::DataType dtype = expr.dtype();
   return GetTypeFromRuntimeDataType(dtype);
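
The op.cc addition teaches GetType() to recover a typed pointer from two
builtins that previously fell through to the generic handle dtype: for
tvm_access_ptr() the element type comes from the leading type_annotation
argument, and for address_of() it comes from the dtype of the BufferLoad
argument. This matters now that subroutine signatures are emitted with
typed pointer parameters. A pure-Python sketch of the rule, not TVM code:

    def get_type(op, args):
        if op == "tir.tvm_access_ptr":
            type_annotation_dtype = args[0]   # dtype of the type_annotation call
            return f"Pointer(PrimType({type_annotation_dtype}))"
        if op == "tir.address_of":
            buffer_load_dtype = args[0]       # dtype of the BufferLoad argument
            return f"Pointer(PrimType({buffer_load_dtype}))"
        return "handle"                       # previous fallback for both cases

    assert get_type("tir.tvm_access_ptr", ["float32"]) == "Pointer(PrimType(float32))"
    assert get_type("tir.address_of", ["int16"]) == "Pointer(PrimType(int16))"
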
diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py
index e001a62ab9..99e2f0c923 100644
--- a/tests/python/relay/aot/test_crt_forward_declarations.py
+++ b/tests/python/relay/aot/test_crt_forward_declarations.py
@@ -33,8 +33,6 @@ from tvm.micro.testing.aot_test_utils import (
     AOTTestRunner,
 )
 
-pytestmark = pytest.mark.skip(reason="regression introduced in #15725")
-
 
 def _change_ndarray_layout(arr, src_layout, dst_layout):
     """Makes a copy of an ndarray, reshaping it to a new data layout.
@@ -162,8 +160,8 @@ def test_internal_calls(interface_api, use_unpacked_api, test_runner):
 
     lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0]
     main_source = lib_mod.get_source()
-    assert main_source.count("int32_t 
tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1
-    assert main_source.count("int32_t tvmgen_default_fused_layout_transform") 
== 3
+    assert main_source.count("int32_t 
tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2
+    assert main_source.count("int32_t tvmgen_default_fused_layout_transform") 
== 6
 
 
 @tvm.testing.requires_corstone300
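
The doubled counts in test_internal_calls reflect the new forward
declarations: every internal function's signature now appears once as a
declaration and once as a definition, so a substring count over the
generated source doubles. (The hunk above also removes the module-level
skip marker, re-enabling tests disabled when #15725 reverted the original
change.) A minimal illustration of the counting logic, using one
hypothetical function:

    source = (
        "int32_t tvmgen_default_fused_layout_transform(void* args);\n"    # declaration
        "int32_t tvmgen_default_fused_layout_transform(void* args) {}\n"  # definition
    )
    assert source.count("int32_t tvmgen_default_fused_layout_transform") == 2
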
diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
index 7bea7577b6..f6145cd1c5 100644
--- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
+++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py
@@ -135,8 +135,13 @@ def test_write_3x3_depthwise_code():
     #define TENSORDOT_OPT_X1_INT16_W48_3X3_000_EXISTS
     #include <arm_acle.h>
    __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w48_3x3_000(
-        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+        int32_t *bias, int32_t *scale
     ) {
+      int32_t *output = output_arg;
+      int32_t *tensor = tensor_arg;
+      int32_t *kernel = kernel_arg;
+
       int32_t sum_0 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -188,8 +193,13 @@ def test_odd_width_3x3_depthwise_strides_code():
     #define TENSORDOT_OPT_X2_INT16_W49_3X3_000_2_4_EXISTS
     #include <arm_acle.h>
    __attribute__((always_inline)) static inline int32_t tensordot_opt_x2_int16_w49_3x3_000_2_4(
-        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+        int32_t *bias, int32_t *scale
     ) {
+      int32_t *output = output_arg;
+      int32_t *tensor = tensor_arg;
+      int32_t *kernel = kernel_arg;
+
       int32_t sum_0 = *bias, sum_1 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -251,8 +261,13 @@ def test_1x1x8_convolution_code():
     #define TENSORDOT_OPT_X4_INT16_W384_1X8_000_8_1_EXISTS
     #include <arm_acle.h>
    __attribute__((always_inline)) static inline int32_t tensordot_opt_x4_int16_w384_1x8_000_8_1(
-        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+        int32_t *bias, int32_t *scale
     ) {
+      int32_t *output = output_arg;
+      int32_t *tensor = tensor_arg;
+      int32_t *kernel = kernel_arg;
+
       int32_t sum_0 = *bias, sum_1 = *bias, sum_2 = *bias, sum_3 = *bias;
 
       int32_t tensor__y00_x00__y00_x01 = tensor[0];
@@ -349,8 +364,13 @@ def test_3x3x3_offset_convolution_code():
     #define TENSORDOT_OPT_X1_INT16_W288_3X9_111_EXISTS
     #include <arm_acle.h>
    __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w288_3x9_111(
-        int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale
+        int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg,
+        int32_t *bias, int32_t *scale
     ) {
+      int32_t *output = output_arg;
+      int32_t *tensor = tensor_arg;
+      int32_t *kernel = kernel_arg;
+
       int32_t sum_0 = *bias;
 
       int32_t tensor__unknown__y00_x00 = tensor[0];
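
In all four expected tensordot kernels above, the parameters are now typed
int16_t* to match the buffers' element type in the typed forward
declarations, then immediately aliased as int32_t* so the body's word-wise
loads are unchanged: each 32-bit read packs two adjacent int16 lanes for
the DSP intrinsics. A small sketch of that packing, with made-up values:

    import struct

    tensor_int16 = [3, -7, 10, 2]                   # the int16_t* view
    raw = struct.pack("<4h", *tensor_int16)         # 8 bytes of payload
    tensor_int32 = struct.unpack("<2i", raw)        # the int32_t* view
    assert len(tensor_int32) == 2                   # two int16 lanes per word
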
diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py
index d02f8744f1..3aca0fc8c7 100644
--- a/tests/python/unittest/test_target_codegen_c_host.py
+++ b/tests/python/unittest/test_target_codegen_c_host.py
@@ -14,11 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
 import tvm
 import tvm.testing
+
 from tvm import te
-import numpy as np
 from tvm.contrib import utils
+from tvm.script import tir as T, ir as I
+
+import numpy as np
 
 
 def test_add():
@@ -228,11 +232,39 @@ def test_call_packed():
     check_global_packed_func()
 
 
+def test_subroutine_call():
+    @I.ir_module
+    class mod:
+        @T.prim_func
+        def main(A: T.Buffer(1, dtype="float32")):
+            mod.subroutine(A.data)
+
+        @T.prim_func(private=True)
+        def subroutine(A_data: T.handle("float32")):
+            A = T.decl_buffer(1, dtype="float32", data=A_data)
+            A[0] = 42.0
+
+    built = tvm.build(mod, target="c")
+
+    func_names = list(built["get_func_names"]())
+    assert (
+        "main" in func_names
+    ), "Externally exposed functions should be listed in available functions."
+    assert (
+        "subroutine" not in func_names
+    ), "Internal function should not be listed in available functions."
+
+    source = built.get_source()
+    assert (
+        source.count("main(void*") == 2
+    ), "Expected two occurrences, for forward-declaration and definition"
+    assert (
+        source.count("subroutine(float*") == 2
+    ), "Expected two occurrences, for forward-declaration and definition"
+    assert (
+        source.count("subroutine(") == 3
+    ), "Expected three occurrences, for forward-declaration, definition, and 
call from main."
+
+
 if __name__ == "__main__":
-    test_add()
-    test_add_pipeline()
-    test_reinterpret()
-    test_ceil()
-    test_floor()
-    test_round()
-    test_call_packed()
+    tvm.testing.main()
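
The new test_subroutine_call pins down the expected shape of the generated
C source. A rough, illustrative rendering of that shape (not the verbatim
TVM output) shows why the counts are 2, 2, and 3:

    source = """
    int32_t main(void* args);                   // forward declaration
    static int32_t subroutine(float* A_data);   // forward declaration
    static int32_t subroutine(float* A_data) {  // definition
      A_data[0] = 42.0f;
      return 0;
    }
    int32_t main(void* args) {                  // definition
      subroutine(A);                            // internal call
      return 0;
    }
    """
    assert source.count("main(void*") == 2          # declaration + definition
    assert source.count("subroutine(float*") == 2   # declaration + definition
    assert source.count("subroutine(") == 3         # + the call inside main
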
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index 588a92d87c..61f0892a9c 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -268,6 +268,7 @@ cast_smem_ptr_to_int(const void* const smem_ptr)
   #define int64_t long long
   #define uint64_t unsigned long long
 #endif
+extern "C" __global__ void __launch_bounds__(16) main_kernel(float* 
__restrict__ A, float* __restrict__ B, float* __restrict__ C);
 extern "C" __global__ void __launch_bounds__(16) main_kernel(float* 
__restrict__ A, float* __restrict__ B, float* __restrict__ C) {
   __shared__ float A_shared[64];
   __shared__ float B_shared[64];
