This is an automated email from the ASF dual-hosted git repository. csullivan pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 698531e6d7 [CodeGenC][Redo] Handle GlobalVar callee as internal function call (#15835) 698531e6d7 is described below commit 698531e6d7e1493b7a73c71137132de87de8aad0 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Wed Oct 18 14:07:43 2023 -0500 [CodeGenC][Redo] Handle GlobalVar callee as internal function call (#15835) * [CodeGenC][Redo] Handle GlobalVar callee as internal function call This reverts commit [`e88d0d`](https://github.com/apache/tvm/pull/15725), which itself reverted [`9ff71f`](https://github.com/apache/tvm/pull/15103) for breakages on the metal backend. Now that the CI contains compile-time testing of the metal codegen, the original breakage should be identifiable. * Added codegen metal CI debug print * Print function decl to the argument stream * Remove the codegen metal CI debug print-outs --- .../arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py | 8 +- .../topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py | 87 +++++++++--- .../arm_cpu/mprofile/dsp/micro_kernel/max_pool.py | 13 +- .../arm_cpu/mprofile/dsp/micro_kernel/tensordot.py | 7 +- .../backend/contrib/cmsisnn/tir_to_runtime.cc | 28 ++-- .../contrib/example_target_hooks/tir_to_runtime.cc | 26 +++- src/relay/backend/contrib/uma/tir_to_runtime.cc | 34 +++-- src/target/opt/build_cuda_on.cc | 18 ++- src/target/source/codegen_aocl.cc | 19 ++- src/target/source/codegen_c.cc | 153 ++++++++++++++------- src/target/source/codegen_c.h | 59 +++++++- src/target/source/codegen_c_host.cc | 93 ++++++------- src/target/source/codegen_c_host.h | 3 +- src/target/source/codegen_cuda.cc | 4 +- src/target/source/codegen_cuda.h | 2 +- src/target/source/codegen_metal.cc | 89 ++++++------ src/target/source/codegen_metal.h | 3 +- src/target/source/codegen_opencl.cc | 24 ++-- src/target/source/codegen_vhls.cc | 34 +++-- src/target/source/codegen_webgpu.cc | 79 +++++------ src/target/source/codegen_webgpu.h | 4 +- src/target/source/source_module.cc | 6 +- src/tir/op/op.cc | 26 ++++ .../relay/aot/test_crt_forward_declarations.py | 6 +- .../topi/python/test_topi_conv2d_tensordot_opts.py | 28 +++- .../python/unittest/test_target_codegen_c_host.py | 48 +++++-- .../test_tir_transform_inject_ptx_async_copy.py | 1 + 27 files changed, 597 insertions(+), 305 deletions(-) diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py index e8e45152aa..3eb32d8fdb 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/avg_pool.py @@ -55,7 +55,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - cc.dtype, + "int32", f"{func_prefix}_{width}_{uniq_id}", aa.access_ptr("r"), cc.access_ptr("w"), @@ -68,7 +68,7 @@ def intrin_sum(shape, in_dtype, out_dtype, reset=False): def _reduce_reset(): ib = tvm.tir.ir_builder.create() ib.emit( - tvm.tir.call_extern(cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w")) + tvm.tir.call_extern("int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w")) ) return ib.get() @@ -113,8 +113,8 @@ extern "C" __attribute__((always_inline)) static inline int32_t sum16_{N}_{uniq_id}( int16_t *arr, int16_t *res16, - long arr_offset, - int reset) {{ + int32_t arr_offset, + int32_t reset) {{ int n; int32_t *p32; int32_t res = reset ? 0 : *res16; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py index 929dcc6557..e26e818fbd 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/gemm.py @@ -156,9 +156,14 @@ __attribute__((always_inline)) static inline const int8_t *read_and_pad(const in extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_body_rest_{uniq_id}( - int K, + int32_t K_arg, int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int K = K_arg; + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int k_base = (K / 4) * 4; switch ( K % 4 ) {{ case 1: @@ -200,7 +205,12 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_loop_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + + for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -221,7 +231,11 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_body_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int16_t bb_pad[{bb_pad_size}]; int32_t retcode = 0; @@ -265,9 +279,14 @@ out: extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{N}_update_rest_{uniq_id}( - int K, + int32_t K_arg, int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int K = K_arg; + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int k_base = (K / 4) * 4; switch ( K % 4 ) {{ case 1: @@ -309,7 +328,11 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_loop_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -327,7 +350,11 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_update_{uniq_id}( int8_t *aa, int8_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int16_t bb_pad[{bb_pad_size}]; int32_t retcode = 0; @@ -368,9 +395,14 @@ out: extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_body_rest_{uniq_id}( - int K, + int32_t K_arg, int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int K = K_arg; + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int k_base = (K / 2) * 2; for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ @@ -387,7 +419,11 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_loop_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -408,7 +444,11 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_body_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int32_t retcode = 0; if ( {M} < 2 && {N} < 2 ) {{ @@ -450,9 +490,14 @@ out: extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{N}_update_rest_{uniq_id}( - int K, + int32_t K_arg, int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int K = K_arg; + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int k_base = (K / 2) * 2; for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ @@ -469,7 +514,11 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_loop_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ int32_t sum = 0; @@ -487,7 +536,11 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t gemm16_{M}x{K}x{N}_update_{uniq_id}( int16_t *aa, int16_t *bb, int32_t *cc, - int A_stride, int B_stride, int C_stride) {{ + int32_t A_stride_arg, int32_t B_stride_arg, int32_t C_stride_arg) {{ + int A_stride = A_stride_arg; + int B_stride = B_stride_arg; + int C_stride = C_stride_arg; + int32_t retcode = 0; if ( {M} < 2 && {N} < 2 ) {{ @@ -520,7 +573,7 @@ out: #ifdef __cplusplus extern "C" #endif -__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int C_stride) {{ +__attribute__((always_inline)) static inline int32_t gemm_{M}x{K}x{N}_reset_{uniq_id}(int32_t *cc, int32_t C_stride) {{ for (int i = 0; i < {M}; i++) {{ for (int j = 0; j < {N}; j++) {{ cc[i*C_stride + j] = 0; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py index 66d712a4a0..cfed417c9f 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/max_pool.py @@ -46,7 +46,7 @@ def intrin_max(shape, in_dtype, out_dtype): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - cc.dtype, + "int32", f"{func_prefix}_{uniq_id}", aa.access_ptr("r"), cc.access_ptr("w"), @@ -59,7 +59,7 @@ def intrin_max(shape, in_dtype, out_dtype): ib = tvm.tir.ir_builder.create() ib.emit( tvm.tir.call_extern( - cc.dtype, f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0] + "int32", f"{func_prefix}_reset_{uniq_id}", cc.access_ptr("w"), cc.strides[0] ) ) return ib.get() @@ -96,7 +96,7 @@ extern "C" #endif __attribute__((always_inline)) static inline int32_t max8_reset_{uniq_id}( int8_t *res, - int N) {{ + int32_t N) {{ memset(res, (int8_t)-128, N * sizeof(*res)); return 0; }} @@ -107,7 +107,9 @@ extern "C" __attribute__((always_inline)) static inline int32_t max8_loop_{uniq_id}( int8_t *arg, int8_t *res, - int N) {{ + int32_t N_arg) {{ + int N = N_arg; + for ( int i = 0; i < N; ++ i ) if ( arg[i] > res[i] ) res[i] = arg[i]; @@ -120,7 +122,8 @@ extern "C" __attribute__((always_inline)) static inline int32_t max8_{uniq_id}( int8_t *arg, int8_t *res, - int N) {{ + int32_t N_arg) {{ + int N = N_arg; int32_t *parg32, *pres32; int una_arg = (int32_t)arg & 0x3, una_res = (int32_t)res & 0x3; int32_t retcode = 0; diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py index d2a8f1ef69..af3b23e01d 100644 --- a/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py +++ b/python/tvm/topi/arm_cpu/mprofile/dsp/micro_kernel/tensordot.py @@ -390,8 +390,13 @@ def tensordot_int16_impl( #define {function_name.upper()}_EXISTS #include <arm_acle.h> __attribute__((always_inline)) static inline int32_t {function_name}( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) {{ + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + {_init_biased_accumulators(num_outputs)} {insert_lines(load_tensor_lines)} diff --git a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc index 186fa30f20..6febfe3486 100644 --- a/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc +++ b/src/relay/backend/contrib/cmsisnn/tir_to_runtime.cc @@ -46,13 +46,6 @@ class CodeGenCMSISNN : public codegen::CodeGenCHost { CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str, devices); } - /*! - * \brief Emit code that offloads a subgraph to the Cortex-M - * - * \return string of code that offloads a subgraph to the Cortex-M - */ - void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } - private: /*! * \brief Enable storing the last error */ bool debug_last_error; @@ -575,11 +568,11 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { bool emit_fwd_func_decl = false; bool debug_last_error = GetCompilerAttrs()->debug_last_error; CodeGenCMSISNN codegen; - Array<String> function_names; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), debug_last_error); - std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs; - for (auto kv : mod->functions) { - funcs.push_back(kv); + + std::vector<std::pair<tvm::GlobalVar, tvm::PrimFunc>> funcs; + for (auto [gvar, base_func] : mod->functions) { + funcs.push_back({gvar, Downcast<PrimFunc>(base_func)}); } std::sort(funcs.begin(), funcs.end(), @@ -594,13 +587,16 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { return name_hint_a < name_hint_b; }); - for (auto kv : funcs) { - auto prim_func = Downcast<PrimFunc>(kv.second); - auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol); - function_names.push_back(global_symbol.value()); - codegen.AddFunction(prim_func); + for (auto [gvar, prim_func] : funcs) { + codegen.AddFunction(gvar, prim_func); } std::string code = codegen.Finish(); + + Array<String> function_names; + for (auto [gvar, prim_func] : funcs) { + function_names.push_back(codegen.GetFunctionName(gvar)); + } + return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc index 0db8d06c31..6f09e0a0c3 100644 --- a/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc +++ b/src/relay/backend/contrib/example_target_hooks/tir_to_runtime.cc @@ -49,16 +49,30 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) { bool emit_asserts = false; bool emit_fwd_func_decl = false; CodeGenExampleTargetHook codegen; - Array<String> function_names; + std::unordered_set<std::string> devices; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); - for (auto kv : mod->functions) { - auto prim_func = Downcast<PrimFunc>(kv.second); - auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol); - function_names.push_back(global_symbol.value()); - codegen.AddFunction(prim_func); + + Map<GlobalVar, PrimFunc> functions; + for (auto [gvar, base_func] : mod->functions) { + auto prim_func = Downcast<PrimFunc>(base_func); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + codegen.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl); } + std::string code = codegen.Finish(); + + Array<String> function_names; + for (auto [gvar, prim_func] : functions) { + function_names.push_back(codegen.GetFunctionName(gvar)); + } + return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/relay/backend/contrib/uma/tir_to_runtime.cc b/src/relay/backend/contrib/uma/tir_to_runtime.cc index 3b58fda54b..487e247f5d 100644 --- a/src/relay/backend/contrib/uma/tir_to_runtime.cc +++ b/src/relay/backend/contrib/uma/tir_to_runtime.cc @@ -49,13 +49,6 @@ class UMACodegen : public codegen::CodeGenCHost { CodeGenCHost::Init(output_ssa, emit_asserts, emit_fwd_func_decl, target_str_, devices); } - /*! - * \brief Emit code that offloads a subgraph to the UMA target - * - * \return string of code that offloads a subgraph to the UMA target - */ - void AddFunction(const PrimFunc& prim_func) { CodeGenC::AddFunction(prim_func); } - private: String target_str_; }; @@ -63,17 +56,30 @@ class UMACodegen : public codegen::CodeGenCHost { runtime::Module TIRToRuntime(IRModule mod, Target target) { bool output_ssa = false; bool emit_asserts = false; - bool emit_fwd_func_decl = false; + bool emit_fwd_func_decl = true; UMACodegen codegen(target->kind->name); - Array<String> function_names; codegen.Init(output_ssa, emit_asserts, emit_fwd_func_decl); - for (auto kv : mod->functions) { - auto prim_func = Downcast<PrimFunc>(kv.second); - auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol); - function_names.push_back(global_symbol.value()); - codegen.AddFunction(prim_func); + + Map<GlobalVar, PrimFunc> functions; + for (auto [gvar, base_func] : mod->functions) { + auto prim_func = Downcast<PrimFunc>(base_func); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + codegen.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + codegen.AddFunction(gvar, prim_func, emit_fwd_func_decl); } + std::string code = codegen.Finish(); + + Array<String> function_names; + for (auto [gvar, prim_func] : functions) { + function_names.push_back(codegen.GetFunctionName(gvar)); + } + return codegen::CSourceModuleCreate(code, "c", function_names); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 1c0b5094ef..e0f53e3509 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -131,13 +131,21 @@ runtime::Module BuildCUDA(IRModule mod, Target target) { CodeGenCUDA cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc"; - auto f = Downcast<PrimFunc>(kv.second); - auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); + Map<GlobalVar, PrimFunc> functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc"; + auto prim_func = Downcast<PrimFunc>(base_func); + auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + cg.AddFunction(gvar, prim_func); } std::string code = cg.Finish(); diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index 700d85b4cc..dc3ba08751 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -40,13 +40,22 @@ runtime::Module BuildAOCL(IRModule mod, Target target, bool emulation) { CodeGenOpenCL cg; cg.Init(output_ssa); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc"; - auto f = Downcast<PrimFunc>(kv.second); - auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); + Map<GlobalVar, PrimFunc> functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenOpenCL: Can only take PrimFunc"; + auto prim_func = Downcast<PrimFunc>(base_func); + auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.DeclareFunction(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.AddFunction(gvar, prim_func); } std::string code = cg.Finish(); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a7cc320562..187bdc74fe 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -42,6 +42,7 @@ void CodeGenC::InitFuncState(const PrimFunc& f) { alloc_storage_scope_.clear(); handle_data_type_.clear(); CodeGenSourceBase::ClearFuncState(); + ReserveKeywordsAsUnique(); } void CodeGenC::ReserveKeywordsAsUnique() { @@ -75,51 +76,92 @@ void CodeGenC::ReserveKeywordsAsUnique() { name_supply_->ReserveName("return"); } -void CodeGenC::AddFunction(const PrimFunc& f) { - // clear previous generated state. - this->InitFuncState(f); - // reserve keywords - ReserveKeywordsAsUnique(); +void CodeGenC::PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) { + PrintFuncPrefix(os); + PrintType(func->ret_type, os); + PrintExtraAttrs(func, os); + os << " " << function_name << "("; + for (size_t i = 0; i < func->params.size(); ++i) { + tir::Var v = func->params[i]; - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); - - this->PrintFuncPrefix(stream); - PrintType(f->ret_type, stream); - this->PrintExtraAttrs(f); - this->stream << " " << static_cast<std::string>(global_symbol.value()) << "("; - - for (size_t i = 0; i < f->params.size(); ++i) { - tir::Var v = f->params[i]; - std::string vid = AllocVarID(v.get()); - if (i != 0) stream << ", "; - if (v.dtype().is_handle()) { - auto it = alloc_storage_scope_.find(v.get()); - if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, stream); - } + if (i > 0) { + os << ", "; + } - PrintType(GetType(v), stream); - // Register handle data type - // TODO(tvm-team): consider simply keep type info in the - // type annotation(via a normalizing rewriting). - if (auto* ptr = v->type_annotation.as<PointerTypeNode>()) { - if (auto* prim = ptr->element_type.as<PrimTypeNode>()) { - RegisterHandleType(v.get(), prim->dtype); - } - } + if (auto it = alloc_storage_scope_.find(v.get()); it != alloc_storage_scope_.end()) { + PrintStorageScope(it->second, os); + } - if (no_alias) { - PrintRestrict(v, stream); + PrintType(GetType(v), os); + + bool no_alias = func->HasNonzeroAttr(tir::attr::kNoAlias); + bool is_handle = v.dtype().is_handle(); + if (no_alias && is_handle) { + PrintRestrict(v, os); + } + + os << " " << AllocVarID(v.get()); + } + os << ")"; + + // Register handle data type + // TODO(tvm-team): consider simply keep type info in the + // type annotation(via a normalizing rewriting). + for (const auto& param : func->params) { + if (auto* ptr = param->type_annotation.as<PointerTypeNode>()) { + if (auto* prim = ptr->element_type.as<PrimTypeNode>()) { + RegisterHandleType(param.get(), prim->dtype); } - } else { - PrintType(GetType(v), stream); } - stream << ' ' << vid; } - stream << ") {\n"; +} + +void CodeGenC::DeclareFunction(const GlobalVar& gvar, const PrimFunc& func) { + if (internal_functions_.count(gvar)) { + return; + } + + auto function_name = [&]() -> String { + if (auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) { + auto name = global_symbol.value(); + ICHECK(!func_name_supply_->ContainsName(name)) + << "Function " << gvar << " must use global symbol " << name + << ", but this name has already been used."; + func_name_supply_->ReserveName(name); + return name; + } else { + func_name_supply_->ReserveName(gvar->name_hint); + return gvar->name_hint; + } + }(); + + internal_functions_.insert({gvar, function_name}); + + InitFuncState(func); + PrintFunctionSignature(function_name, func, fwd_decl_stream); + fwd_decl_stream << ";\n"; +} + +String CodeGenC::GetFunctionName(const GlobalVar& gvar) { + auto it = internal_functions_.find(gvar); + ICHECK(it != internal_functions_.end()) + << "Attempted to find name of " << gvar + << ", but no function with this GlobalVar has been declared"; + return it->second; +} + +void CodeGenC::AddFunction(const GlobalVar& gvar, const PrimFunc& f) { + // If the function has already been forward-declared, this is a + // no-op. + DeclareFunction(gvar, f); + auto function_name = GetFunctionName(gvar); + + // clear previous generated state. + InitFuncState(f); + + PrintFunctionSignature(function_name, f, stream); + stream << " {\n"; this->PreFunctionBody(f); int func_scope = this->BeginScope(); this->PrintStmt(f->body); @@ -130,9 +172,15 @@ void CodeGenC::AddFunction(const PrimFunc& f) { void CodeGenC::PrintFuncPrefix(std::ostream& os) {} -void CodeGenC::PrintExtraAttrs(const PrimFunc& f) {} +void CodeGenC::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) {} -std::string CodeGenC::Finish() { return decl_stream.str() + stream.str(); } +std::string CodeGenC::Finish() { + std::ostringstream code; + code << decl_stream.str(); + code << fwd_decl_stream.str(); + code << stream.str(); + return code.str(); +} void CodeGenC::PrintExpr(const PrimExpr& n, std::ostream& os) { // NOLINT(*) if (print_ssa_form_) { @@ -542,12 +590,17 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) ICHECK_GE(op->args.size(), 1U); auto func = Downcast<StringImm>(op->args[0]); this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), func->value, op->args, true, os); - Array<Type> arg_types; - for (size_t i = 1; i < op->args.size(); i++) { - arg_types.push_back(GetType(op->args[i])); + + // If the call_extern refers to an function within the IRModule, then + // the forward declaration is already provided from DeclareFunction. + if (!func_name_supply_->ContainsName(func->value)) { + Array<Type> arg_types; + for (size_t i = 1; i < op->args.size(); i++) { + arg_types.push_back(GetType(op->args[i])); + } + Type ret_type = GetTypeFromRuntimeDataType(op->dtype); + this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type); } - Type ret_type = GetTypeFromRuntimeDataType(op->dtype); - this->GenerateForwardFunctionDeclarations(func->value, arg_types, ret_type); } else if (op_attr_global_symbol_.count(call_op)) { // call extern if the op itself have a global symbol. this->PrintCallExtern(GetType(GetRef<PrimExpr>(op)), op_attr_global_symbol_[call_op], @@ -615,9 +668,13 @@ void CodeGenC::VisitExpr_(const CallNode* op, std::ostream& os) { // NOLINT(*) } else { LOG(FATAL) << "Unresolved call " << op->op; } + } else if (auto opt = op->op.as<GlobalVar>()) { + auto gvar = opt.value(); + auto callee_name = GetFunctionName(gvar); + PrintCallExtern(GetType(GetRef<PrimExpr>(op)), callee_name, op->args, false, os); } else { - ICHECK(op->op.as<GlobalVarNode>()); - LOG(FATAL) << "Do not yet support cross function call"; + LOG(FATAL) << "CodeGenC: Unknown operation " << op->op << " is neither a recognized built-in, " + << "nor a GlobalVar reference to another function in the IRModule"; } } diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 93f9ea519c..2921a56ef3 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -65,12 +65,33 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, * \param output_ssa Whether output SSA. */ void Init(bool output_ssa); + /*! - * \brief Add the function to the generated module. - * \param f The function to be compiled. + * \brief Add the function declaration to the generated module, + * without defining it. + * + * \param gvar The GlobalVar representing the function. + * \param func The function to be compiled. * \param whether to append return 0 in the end. */ - void AddFunction(const PrimFunc& f); + virtual void DeclareFunction(const GlobalVar& gvar, const PrimFunc& func); + + /*! + * \brief Add the function to the generated module, including its + * declaration and definition. + * + * \param gvar The GlobalVar representing the function. + * \param func The function to be compiled. + */ + virtual void AddFunction(const GlobalVar& gvar, const PrimFunc& func); + + /*! + * \brief Get the name of a declared function + * \param gvar The GlobalVar of the function + * \returns The string name of the function + */ + String GetFunctionName(const GlobalVar& gvar); + /*! * \brief Finalize the compilation and return the code. * \return The code. @@ -96,7 +117,23 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, PrintExpr(n, os); return os.str(); } + // The following parts are overloadable print operations. + + /*! \brief Print the function signature before the argument list + * + * The default implementation delegates out to PrintFuncPrefix and + * PrintExtraAttrs. + * + * \param function_name The name of the function + * + * \param func The function whose signature should be printed + * + * \param os The output stream + */ + virtual void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os); + /*! * \brief Print the function header before the argument list * \param os The output stream @@ -109,7 +146,7 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, * * Example: __launch_bounds__(256) for CUDA functions */ - virtual void PrintExtraAttrs(const PrimFunc& f); + virtual void PrintExtraAttrs(const PrimFunc& f, std::ostream& os); // NOLINT(*) /*! * \brief Insert statement before function body. * \param f The function to be compiled. @@ -284,10 +321,24 @@ class CodeGenC : public ExprFunctor<void(const PrimExpr&, std::ostream&)>, private: /*! \brief set of volatile buf access */ std::unordered_set<const VarNode*> volatile_buf_; + // deep comparison of PrimExpr ExprDeepEqual deep_equal_; + // binding of let variables. Enables duplicate var defs that map to same value std::unordered_map<Var, const LetNode*, ObjectPtrHash, ObjectPtrEqual> let_binding_; + + /* \brief Map of GlobalVar to their symbol. + * + * For externally-exposed functions, this is given by the + * tvm::attr::kTarget attribute of the PrimFunc. For internal + * functions, this is the name of the function's GlobalVar, possibly + * altered to prevent duplicate names. + */ + std::unordered_map<GlobalVar, String, ObjectPtrHash, ObjectPtrEqual> internal_functions_; + + /* \brief Name supply to generate unique function names */ + NameSupply func_name_supply_{""}; }; } // namespace codegen diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 3255e11c5d..caef43e8af 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -75,19 +75,24 @@ void CodeGenCHost::InitGlobalContext() { void CodeGenCHost::DefineModuleName() { decl_stream << "void* " << module_name_ << " = NULL;\n"; } -void CodeGenCHost::AddFunction(const PrimFunc& f, bool emit_fwd_func_decl) { - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenCHost: Expect PrimFunc to have the global_symbol attribute"; - function_names_.push_back(global_symbol.value()); +void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, + bool emit_fwd_func_decl) { + auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol); + if (global_symbol) { + function_names_.push_back(global_symbol.value()); + } emit_fwd_func_decl_ = emit_fwd_func_decl; - CodeGenC::AddFunction(f); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + CodeGenC::AddFunction(gvar, func); + if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { + ICHECK(global_symbol.defined()) + << "CodeGenCHost: The entry func must have the global_symbol attribute, " + << "but function " << gvar << " only has attributes " << func->attrs; + function_names_.push_back(runtime::symbol::tvm_module_main); stream << "// CodegenC: NOTE: Auto-generated entry function\n"; PrintFuncPrefix(stream); - PrintType(f->ret_type, stream); + PrintType(func->ret_type, stream); stream << " " << tvm::runtime::symbol::tvm_module_main << "(void* args, int* arg_type_ids, int num_args, void* out_ret_value, " << "int* out_ret_tcode, void* resource_handle) {\n"; @@ -128,15 +133,6 @@ void CodeGenCHost::PrintFuncPrefix(std::ostream& os) { // NOLINT(*) << "TVM_DLL "; } -std::string CodeGenCHost::Finish() { // NOLINT(*) - std::string ret = decl_stream.str(); - if (emit_fwd_func_decl_) { - ret += fwd_decl_stream.str(); - } - ret += stream.str(); - return ret; -} - void CodeGenCHost::PrintType(DataType t, std::ostream& os) { // NOLINT(*) int lanes = t.lanes(); if (t.is_handle()) { @@ -437,42 +433,38 @@ runtime::Module BuildCHost(IRModule mod, Target target) { CodeGenCHost cg; cg.Init(output_ssa, emit_asserts, emit_fwd_func_decl, target->str(), devices); cg.SetConstantsByteAlignment(target->GetAttr<Integer>("constants-byte-alignment").value_or(16)); - PrimFunc aot_executor_fn; - - std::vector<std::pair<tvm::GlobalVar, tvm::BaseFunc>> funcs; - for (auto kv : mod->functions) { - // Make sure that the executor function is the last one to be code generated so that all the - // symbols are available to __tvm_main__ - auto fun_name = std::string(kv.first->name_hint); - bool is_aot_executor_fn = kv.second->GetAttr<Bool>("runner_function", Bool(false)).value(); - - if (is_aot_executor_fn) { - aot_executor_fn = Downcast<PrimFunc>(kv.second); - continue; - } - funcs.push_back(kv); + + auto is_aot_executor_fn = [](const PrimFunc& func) -> bool { + return func->GetAttr<Bool>("runner_function", Bool(false)).value(); + }; + + std::vector<std::pair<GlobalVar, PrimFunc>> funcs; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc"; + auto prim_func = Downcast<PrimFunc>(base_func); + funcs.push_back({gvar, prim_func}); } // Sort functions - std::sort(funcs.begin(), funcs.end(), - [](std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_a, - std::pair<tvm::GlobalVar, tvm::BaseFunc> kv_b) { - std::string name_hint_a = kv_a.first->name_hint; - std::string name_hint_b = kv_b.first->name_hint; - return name_hint_a < name_hint_b; - }); - - // Add all functions except __tvm_main__ - for (auto& kv : funcs) { - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodegenCHost: Can only take PrimFunc"; - auto f = Downcast<PrimFunc>(kv.second); - cg.AddFunction(f); + auto sort_key = [&is_aot_executor_fn](const auto& kv) { + return std::tuple{is_aot_executor_fn(kv.second), kv.first->name_hint}; + }; + std::sort(funcs.begin(), funcs.end(), [&sort_key](const auto& kv_a, const auto& kv_b) { + return sort_key(kv_a) < sort_key(kv_b); + }); + + // Declare all functions first. This ensures that all functions, + // including the __tvm_main__ used in AOT, have access to forward + // declarations of other functions in the IRModule. + for (const auto& [gvar, prim_func] : funcs) { + cg.DeclareFunction(gvar, prim_func); } - // Add __tvm_main__ - if (aot_executor_fn.defined()) { - emit_fwd_func_decl = true; - cg.AddFunction(aot_executor_fn, emit_fwd_func_decl); + // Codegen all functions. Passing emit_fwd_func_decl=true adds a + // forward declaration for any `builtin::call_extern`, based on the + // arguments provided to it. + for (const auto& [gvar, prim_func] : funcs) { + cg.AddFunction(gvar, prim_func, emit_fwd_func_decl); } // NOTE: it's possible that kRuntime attr is not attached when the mod was built with tvm.build(). @@ -484,7 +476,10 @@ runtime::Module BuildCHost(IRModule mod, Target target) { } else { runtime = relay::Runtime::Create("cpp", {}); } - if (aot_executor_fn.defined() && runtime->name == relay::kTvmRuntimeCpp) { + + bool has_aot_executor_fn = std::any_of( + funcs.begin(), funcs.end(), [&](const auto& kv) { return is_aot_executor_fn(kv.second); }); + if (has_aot_executor_fn && runtime->name == relay::kTvmRuntimeCpp) { cg.InitGlobalContext(); } diff --git a/src/target/source/codegen_c_host.h b/src/target/source/codegen_c_host.h index 694104afc0..aeba685f74 100644 --- a/src/target/source/codegen_c_host.h +++ b/src/target/source/codegen_c_host.h @@ -44,8 +44,7 @@ class CodeGenCHost : public CodeGenC { const std::unordered_set<std::string>& devices); void InitGlobalContext(); - void AddFunction(const PrimFunc& f, bool emit_fwd_func_decl = false); - std::string Finish() final; + void AddFunction(const GlobalVar& gvar, const PrimFunc& f, bool emit_fwd_func_decl = false); /*! * \brief Add functions from the (unordered) range to the current module in a deterministic * order. This helps with debugging. diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index a91f8b0164..7639ce6065 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -75,7 +75,7 @@ class ThreadIdxExtractor : public tir::StmtVisitor { PrimExpr threadIdx_z_ext = Integer(1); }; -void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) { +void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f, std::ostream& os) { ThreadIdxExtractor extractor; extractor(f->body); arith::Analyzer analyzer; @@ -86,7 +86,7 @@ void CodeGenCUDA::PrintExtraAttrs(const PrimFunc& f) { // unable to extract the number of threads per block, hence directly return return; } - stream << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; + os << " __launch_bounds__(" << threadIdx_ext_int->value << ")"; } } diff --git a/src/target/source/codegen_cuda.h b/src/target/source/codegen_cuda.h index 3ec0c3bc2d..bc7b34b500 100644 --- a/src/target/source/codegen_cuda.h +++ b/src/target/source/codegen_cuda.h @@ -47,7 +47,7 @@ class CodeGenCUDA final : public CodeGenC { } // override behavior void PrintFuncPrefix(std::ostream& os) final; - void PrintExtraAttrs(const PrimFunc& f) final; + void PrintExtraAttrs(const PrimFunc& f, std::ostream& os) final; // NOLINT(*) void VisitStmt_(const ForNode* op) final; void PrintStorageSync(const CallNode* op) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index b8c30691e2..ebb7566489 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -36,6 +36,8 @@ namespace codegen { void CodeGenMetal::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); + // skip the first underscore, so SSA variable starts from _1 + name_supply_->FreshName("v_"); // analyze the data; for (Var arg : f->params) { if (arg.dtype().is_handle()) { @@ -52,37 +54,33 @@ CodeGenMetal::CodeGenMetal(Target target) : target_(target) { << "};\n\n"; } -void CodeGenMetal::AddFunction(const PrimFunc& f) { - // clear previous generated state. - this->InitFuncState(f); - // skip the first underscore, so SSA variable starts from _1 - name_supply_->FreshName("v_"); - +void CodeGenMetal::PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) { // add to alloc buffer type. - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); + auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. - this->stream << "kernel void " << static_cast<std::string>(global_symbol.value()) << "("; + os << "kernel void " << static_cast<std::string>(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; size_t limit = target_->GetAttr<Integer>("max_function_args").value().IntValue(); - if (f->params.size() > limit) { + if (func->params.size() > limit) { LOG(WARNING) << "Probably you won't be able to execute your kernel due to high number of " "buffers in the kernel"; } - for (size_t i = 0; i < f->params.size(); ++i, ++num_buffer) { - Var v = f->params[i]; + for (size_t i = 0; i < func->params.size(); ++i, ++num_buffer) { + Var v = func->params[i]; if (!v.dtype().is_handle()) break; - stream << " "; + os << " "; std::string vid = AllocVarID(v.get()); auto it = alloc_storage_scope_.find(v.get()); if (it != alloc_storage_scope_.end()) { - PrintStorageScope(it->second, stream); + PrintStorageScope(it->second, os); } - PrintType(GetType(v), stream); + PrintType(GetType(v), os); // Register handle data type // TODO(tvm-team): consider simply keep type info in the // type annotation(via a normalizing rewriting). @@ -91,19 +89,18 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { RegisterHandleType(v.get(), prim->dtype); } } - stream << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; + os << ' ' << vid << " [[ buffer(" << i << ") ]],\n"; } // Setup normal arguments. - size_t nargs = f->params.size() - num_buffer; + size_t nargs = func->params.size() - num_buffer; std::string varg = name_supply_->FreshName("arg"); if (nargs != 0) { std::string arg_buf_type = static_cast<std::string>(global_symbol.value()) + "_args_t"; - stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer - << ") ]],\n"; + os << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; // declare the struct decl_stream << "struct " << arg_buf_type << " {\n"; - for (size_t i = num_buffer; i < f->params.size(); ++i) { - Var v = f->params[i]; + for (size_t i = num_buffer; i < func->params.size(); ++i) { + Var v = func->params[i]; ICHECK(!v.dtype().is_handle()); std::string vid = AllocVarID(v.get()); std::ostringstream vref; @@ -131,7 +128,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); int work_dim = 0; - auto launch_params = f->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value(); + auto launch_params = func->GetAttr<Array<String>>(tir::attr::kKernelLaunchParams).value(); for (const auto& tag : launch_params) { if (tag != runtime::launch_param::kUseDynamicSharedMemoryTag) { runtime::ThreadScope scope = runtime::ThreadScope::Create(tag); @@ -141,22 +138,16 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { if (work_dim != 0) { // use ushort by default for now - stream << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); - stream << " blockIdx [[threadgroup_position_in_grid]],\n"; - stream << " "; - PrintType(DataType::UInt(thread_index_bits_, work_dim), stream); - stream << " threadIdx [[thread_position_in_threadgroup]]\n"; + os << " "; + PrintType(DataType::UInt(thread_index_bits_, work_dim), os); + os << " blockIdx [[threadgroup_position_in_grid]],\n"; + os << " "; + PrintType(DataType::UInt(thread_index_bits_, work_dim), os); + os << " threadIdx [[thread_position_in_threadgroup]]\n"; } thread_work_dim_ = work_dim; - // the function scope. - stream << ") {\n"; - int func_scope = this->BeginScope(); - this->PrintStmt(f->body); - this->EndScope(func_scope); - this->PrintIndent(); - this->stream << "}\n\n"; + os << ")"; } void CodeGenMetal::BindThreadIndex(const IterVar& iv) { @@ -342,27 +333,33 @@ runtime::Module BuildMetal(IRModule mod, Target target) { const auto* fmetal_compile = Registry::Get("tvm_callback_metal_compile"); std::string fmt = fmetal_compile ? "metallib" : "metal"; - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc"; - auto global_symbol = kv.second->GetAttr<String>(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()); - std::string func_name = global_symbol.value(); + Map<GlobalVar, PrimFunc> functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenMetal: Can only take PrimFunc"; + auto calling_conv = base_func->GetAttr<Integer>(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + + auto prim_func = Downcast<PrimFunc>(base_func); + functions.Set(gvar, prim_func); + } - source_maker << "// Function: " << func_name << "\n"; + for (auto [gvar, prim_func] : functions) { + source_maker << "// Function: " << gvar->name_hint << "\n"; CodeGenMetal cg(target); cg.Init(output_ssa); - auto f = Downcast<PrimFunc>(kv.second); - auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + for (auto [other_gvar, other_prim_func] : functions) { + cg.DeclareFunction(other_gvar, other_prim_func); + } + cg.AddFunction(gvar, prim_func); + std::string fsource = cg.Finish(); source_maker << fsource << "\n"; if (fmetal_compile) { fsource = (*fmetal_compile)(fsource, target).operator std::string(); } - smap[func_name] = fsource; + smap[cg.GetFunctionName(gvar)] = fsource; } return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str()); diff --git a/src/target/source/codegen_metal.h b/src/target/source/codegen_metal.h index 36be10d163..26c991e60d 100644 --- a/src/target/source/codegen_metal.h +++ b/src/target/source/codegen_metal.h @@ -38,7 +38,8 @@ class CodeGenMetal final : public CodeGenC { explicit CodeGenMetal(Target target); // override print thread tag. void PrintArgUnionDecl(); - void AddFunction(const PrimFunc& f); // NOLINT(*) + void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) override; void InitFuncState(const PrimFunc& f) final; void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) void PrintStorageSync(const CallNode* op) final; // NOLINT(*) diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index c15d2253d7..da6a4de619 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -595,18 +595,26 @@ runtime::Module BuildOpenCL(IRModule mod, Target target) { using tvm::runtime::Registry; bool output_ssa = false; + Map<GlobalVar, PrimFunc> functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc"; + auto prim_func = Downcast<PrimFunc>(base_func); + auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv); + ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) + << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; + functions.Set(gvar, prim_func); + } + std::stringstream code; const auto* fpostproc = Registry::Get("tvm_callback_opencl_postproc"); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc"; - code << "// Function: " << kv.first->name_hint << std::endl; + for (auto [gvar, prim_func] : functions) { + code << "// Function: " << gvar->name_hint << std::endl; CodeGenOpenCL cg; cg.Init(output_ssa); - auto f = Downcast<PrimFunc>(kv.second); - auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); - ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) - << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + for (auto [other_gvar, other_prim_func] : functions) { + cg.DeclareFunction(other_gvar, other_prim_func); + } + cg.AddFunction(gvar, prim_func); std::string fsource = cg.Finish(); if (fpostproc) { fsource = (*fpostproc)(fsource, target).operator std::string(); diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 83046de107..aa7a32320c 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -145,13 +145,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { // Generate source code for get_source(). cg.Init(output_ssa); - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc"; - auto f = Downcast<PrimFunc>(kv.second); - auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); + Map<GlobalVar, PrimFunc> functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenVHLS: Can only take PrimFunc"; + auto prim_func = Downcast<PrimFunc>(base_func); + auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - cg.AddFunction(f); + functions.Set(gvar, prim_func); + } + + for (auto [gvar, prim_func] : functions) { + cg.DeclareFunction(gvar, prim_func); + } + for (auto [gvar, prim_func] : functions) { + cg.AddFunction(gvar, prim_func); } std::string whole_code = cg.Finish(); @@ -159,21 +167,21 @@ runtime::Module BuildSDAccel(IRModule mod, Target target) { // Generate source code for compilation. Array<Array<runtime::String>> kernel_info; - for (auto kv : mod->functions) { - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenOpenCL: Can only take PrimFunc"; - auto f = Downcast<PrimFunc>(kv.second); + for (auto [gvar, prim_func] : functions) { CodeGenVivadoHLS cg; cg.Init(output_ssa); - cg.AddFunction(f); + + for (auto [other_gvar, other_prim_func] : functions) { + cg.DeclareFunction(other_gvar, other_prim_func); + } + cg.AddFunction(gvar, prim_func); std::string code = cg.Finish(); if (const auto* f = runtime::Registry::Get("tvm_callback_vhls_postproc")) { code = (*f)(code, target).operator std::string(); } - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - kernel_info.push_back({global_symbol.value(), code}); + auto function_name = cg.GetFunctionName(gvar); + kernel_info.push_back({function_name, code}); } std::string xclbin; diff --git a/src/target/source/codegen_webgpu.cc b/src/target/source/codegen_webgpu.cc index 4d1d834c7f..6a6712a4ce 100644 --- a/src/target/source/codegen_webgpu.cc +++ b/src/target/source/codegen_webgpu.cc @@ -45,6 +45,12 @@ std::string CodeGenWebGPU::Finish() { void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { CodeGenC::InitFuncState(f); + // skip the first underscore, so SSA variable starts from + name_supply_->FreshName("v_"); + // Setup the thread group info. + ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); + ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); + // analyze the data; for (Var arg : f->params) { if (arg.dtype().is_handle()) { @@ -56,28 +62,12 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) { CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {} -void CodeGenWebGPU::AddFunction(const PrimFunc& f) { - // clear previous generated state. - this->InitFuncState(f); - // skip the first underscore, so SSA variable starts from - name_supply_->FreshName("v_"); - // Setup the thread group info. - ICHECK_EQ(name_supply_->FreshName("threadIdx"), "threadIdx"); - ICHECK_EQ(name_supply_->FreshName("blockIdx"), "blockIdx"); - - // add to alloc buffer type. - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.defined()) - << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; - - decl_stream << "//----------------------------------------\n" - << "// function: " << global_symbol.value() << "\n" - << "//----------------------------------------\n"; - +void CodeGenWebGPU::PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) { std::vector<Var> pod_args; int num_buffer = 0; // setup buffer argumemts - for (Var arg : f->params) { + for (Var arg : func->params) { DataType t = arg.dtype(); if (t.is_handle()) { auto* ptr = arg->type_annotation.as<PointerTypeNode>(); @@ -111,16 +101,18 @@ void CodeGenWebGPU::AddFunction(const PrimFunc& f) { } // add to alloc buffer type. // Function header. - this->stream << "fn main(\n" - << " @builtin(workgroup_id) blockIdx : vec3<u32>,\n" - << " @builtin(local_invocation_id) threadIdx : vec3<u32>\n" - << ") {\n"; - // the function scope. - int func_scope = this->BeginScope(); - this->PrintStmt(f->body); - this->EndScope(func_scope); - this->PrintIndent(); - this->stream << "}\n\n"; + os << "fn main(\n" + << " @builtin(workgroup_id) blockIdx : vec3<u32>,\n" + << " @builtin(local_invocation_id) threadIdx : vec3<u32>\n" + << ")"; +} + +void CodeGenWebGPU::AddFunction(const GlobalVar& gvar, const PrimFunc& func) { + CodeGenC::AddFunction(gvar, func); + decl_stream << "//----------------------------------------\n" + << "// function: " << GetFunctionName(gvar) << "\n" + << "//----------------------------------------\n"; + // anotate workgroup this->fwd_decl_stream << "@compute @workgroup_size(" << workgroup_size_[0] << ", " << workgroup_size_[1] << ", " << workgroup_size_[2] << ")\n"; @@ -524,22 +516,31 @@ runtime::Module BuildWebGPU(IRModule mod, Target target) { mod = tir::transform::PointerValueTypeRewrite()(std::move(mod)); bool output_ssa = false; - std::unordered_map<std::string, std::string> smap; - for (auto kv : mod->functions) { - CodeGenWebGPU cg(target); - ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc"; - auto f = Downcast<PrimFunc>(kv.second); - auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv); + Map<GlobalVar, PrimFunc> functions; + for (auto [gvar, base_func] : mod->functions) { + ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenWebGPU: Can only take PrimFunc"; + auto prim_func = Downcast<PrimFunc>(base_func); + auto calling_conv = prim_func->GetAttr<Integer>(tvm::attr::kCallingConv); ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenWebGPU: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); + auto global_symbol = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol); ICHECK(global_symbol.defined()) << "CodeGenWebGPU: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol.value(); + functions.Set(gvar, prim_func); + } + + std::unordered_map<std::string, std::string> smap; + for (auto [gvar, prim_func] : functions) { + CodeGenWebGPU cg(target); cg.Init(output_ssa); - cg.AddFunction(f); + + for (auto [other_gvar, other_prim_func] : functions) { + cg.DeclareFunction(other_gvar, other_prim_func); + } + cg.AddFunction(gvar, prim_func); + std::string code = cg.Finish(); - smap[f_name] = code; + smap[cg.GetFunctionName(gvar)] = code; } auto n = make_object<WebGPUSourceModuleNode>(smap, ExtractFuncInfo(mod)); return runtime::Module(n); diff --git a/src/target/source/codegen_webgpu.h b/src/target/source/codegen_webgpu.h index 57f226ba8a..6ae942a3ad 100644 --- a/src/target/source/codegen_webgpu.h +++ b/src/target/source/codegen_webgpu.h @@ -48,7 +48,9 @@ class CodeGenWebGPU final : public CodeGenC { explicit CodeGenWebGPU(Target target); // overrides std::string Finish() final; - void AddFunction(const PrimFunc& f); // NOLINT(*) + void PrintFunctionSignature(const String& function_name, const PrimFunc& func, + std::ostream& os) final; + void AddFunction(const GlobalVar& gvar, const PrimFunc& f) final; void InitFuncState(const PrimFunc& f) final; void PrintStorageSync(const CallNode* op) final; // NOLINT(*) void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index a6f4b5bb3e..90640a6db6 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -613,12 +613,14 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } for (const tir::Var& pool_var : metadata_->pools) { + call_args_ss << "((uint8_t*)"; String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name; if (IsInternalWorkspaceBuffer(pool_var)) { - call_args_ss << "&" << pool_name << ","; + call_args_ss << "&" << pool_name; } else { - call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name) << ","; + call_args_ss << "workspace_pools->" << tvm::runtime::SanitizeName(pool_name); } + call_args_ss << "),"; } for (const String& device : metadata_->devices) { call_args_ss << "devices->" << device << ","; diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc index 39214c4546..fd14f48921 100644 --- a/src/tir/op/op.cc +++ b/src/tir/op/op.cc @@ -70,6 +70,32 @@ Type GetType(const PrimExpr& expr) { return ptr->type_annotation; } } + + if (auto* access = expr.as<tir::CallNode>()) { + if (access->op.same_as(builtin::tvm_access_ptr())) { + ICHECK(access->args.size()) << "Builtin tvm_access_ptr() may not have empty arguments"; + auto type_annotation = Downcast<Call>(access->args[0]); + static auto builtin_op = Op::Get("tir.type_annotation"); + ICHECK(type_annotation->op.same_as(builtin_op)) + << "Expected the first argument of builtin tvm_access_ptr() " + << "to be a type annotation, but found " << type_annotation->op; + return PointerType(PrimType(type_annotation->dtype)); + } + } + + if (auto* address_of = expr.as<tir::CallNode>()) { + if (address_of->op.same_as(builtin::address_of())) { + ICHECK_EQ(address_of->args.size(), 1) + << "Builtin address_of() expects a single argument, but received arguments " + << address_of->args; + auto* address = address_of->args[0].as<BufferLoadNode>(); + ICHECK(address) + << "Builtin address_of() expects the argument to be a BufferLoad, but received argument " + << address_of->args[0]; + + return PointerType(PrimType(address->dtype)); + } + } // Default: return the type indicated by the dtype. runtime::DataType dtype = expr.dtype(); return GetTypeFromRuntimeDataType(dtype); diff --git a/tests/python/relay/aot/test_crt_forward_declarations.py b/tests/python/relay/aot/test_crt_forward_declarations.py index e001a62ab9..99e2f0c923 100644 --- a/tests/python/relay/aot/test_crt_forward_declarations.py +++ b/tests/python/relay/aot/test_crt_forward_declarations.py @@ -33,8 +33,6 @@ from tvm.micro.testing.aot_test_utils import ( AOTTestRunner, ) -pytestmark = pytest.mark.skip(reason="regression introduced in #15725") - def _change_ndarray_layout(arr, src_layout, dst_layout): """Makes a copy of an ndarray, reshaping it to a new data layout. @@ -162,8 +160,8 @@ def test_internal_calls(interface_api, use_unpacked_api, test_runner): lib_mod = compiled_models[0].executor_factory.lib.imported_modules[0] main_source = lib_mod.get_source() - assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 1 - assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 3 + assert main_source.count("int32_t tvmgen_default_fused_nn_contrib_depthwise_conv2d_NCHWc") == 2 + assert main_source.count("int32_t tvmgen_default_fused_layout_transform") == 6 @tvm.testing.requires_corstone300 diff --git a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py index 7bea7577b6..f6145cd1c5 100644 --- a/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py +++ b/tests/python/topi/python/test_topi_conv2d_tensordot_opts.py @@ -135,8 +135,13 @@ def test_write_3x3_depthwise_code(): #define TENSORDOT_OPT_X1_INT16_W48_3X3_000_EXISTS #include <arm_acle.h> __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w48_3x3_000( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) { + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + int32_t sum_0 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -188,8 +193,13 @@ def test_odd_width_3x3_depthwise_strides_code(): #define TENSORDOT_OPT_X2_INT16_W49_3X3_000_2_4_EXISTS #include <arm_acle.h> __attribute__((always_inline)) static inline int32_t tensordot_opt_x2_int16_w49_3x3_000_2_4( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) { + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + int32_t sum_0 = *bias, sum_1 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -251,8 +261,13 @@ def test_1x1x8_convolution_code(): #define TENSORDOT_OPT_X4_INT16_W384_1X8_000_8_1_EXISTS #include <arm_acle.h> __attribute__((always_inline)) static inline int32_t tensordot_opt_x4_int16_w384_1x8_000_8_1( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) { + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + int32_t sum_0 = *bias, sum_1 = *bias, sum_2 = *bias, sum_3 = *bias; int32_t tensor__y00_x00__y00_x01 = tensor[0]; @@ -349,8 +364,13 @@ def test_3x3x3_offset_convolution_code(): #define TENSORDOT_OPT_X1_INT16_W288_3X9_111_EXISTS #include <arm_acle.h> __attribute__((always_inline)) static inline int32_t tensordot_opt_x1_int16_w288_3x9_111( - int32_t *output, int32_t *tensor, int32_t *kernel, int32_t *bias, int32_t *scale + int16_t *output_arg, int16_t *tensor_arg, int16_t *kernel_arg, + int32_t *bias, int32_t *scale ) { + int32_t *output = output_arg; + int32_t *tensor = tensor_arg; + int32_t *kernel = kernel_arg; + int32_t sum_0 = *bias; int32_t tensor__unknown__y00_x00 = tensor[0]; diff --git a/tests/python/unittest/test_target_codegen_c_host.py b/tests/python/unittest/test_target_codegen_c_host.py index d02f8744f1..3aca0fc8c7 100644 --- a/tests/python/unittest/test_target_codegen_c_host.py +++ b/tests/python/unittest/test_target_codegen_c_host.py @@ -14,11 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import tvm import tvm.testing + from tvm import te -import numpy as np from tvm.contrib import utils +from tvm.script import tir as T, ir as I + +import numpy as np def test_add(): @@ -228,11 +232,39 @@ def test_call_packed(): check_global_packed_func() +def test_subroutine_call(): + @I.ir_module + class mod: + @T.prim_func + def main(A: T.Buffer(1, dtype="float32")): + mod.subroutine(A.data) + + @T.prim_func(private=True) + def subroutine(A_data: T.handle("float32")): + A = T.decl_buffer(1, dtype="float32", data=A_data) + A[0] = 42.0 + + built = tvm.build(mod, target="c") + + func_names = list(built["get_func_names"]()) + assert ( + "main" in func_names + ), "Externally exposed functions should be listed in available functions." + assert ( + "subroutine" not in func_names + ), "Internal function should not be listed in available functions." + + source = built.get_source() + assert ( + source.count("main(void*") == 2 + ), "Expected two occurrences, for forward-declaration and definition" + assert ( + source.count("subroutine(float*") == 2 + ), "Expected two occurrences, for forward-declaration and definition" + assert ( + source.count("subroutine(") == 3 + ), "Expected three occurrences, for forward-declaration, definition, and call from main." + + if __name__ == "__main__": - test_add() - test_add_pipeline() - test_reinterpret() - test_ceil() - test_floor() - test_round() - test_call_packed() + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py index 588a92d87c..61f0892a9c 100644 --- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py +++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py @@ -268,6 +268,7 @@ cast_smem_ptr_to_int(const void* const smem_ptr) #define int64_t long long #define uint64_t unsigned long long #endif +extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C); extern "C" __global__ void __launch_bounds__(16) main_kernel(float* __restrict__ A, float* __restrict__ B, float* __restrict__ C) { __shared__ float A_shared[64]; __shared__ float B_shared[64];