This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main-mod in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 4f5f2d9571f2ec8dde7ee5570752c20cafbcc256 Author: tqchen <tianqi.tc...@gmail.com> AuthorDate: Tue Aug 12 10:22:42 2025 -0400 [REFACTOR] Phase out entry_func This PR phases out the entry function. Previously, a module itself had a default entry function. Supporting such a feature requires extra duplicated code and indirection. Most of our use cases can instead name the function explicitly, e.g. mod["main"], so this PR phases out this behavior. --- include/tvm/runtime/module.h | 2 - include/tvm/tir/function.h | 10 ----- include/tvm/tir/transform.h | 6 --- jvm/core/src/main/java/org/apache/tvm/Module.java | 2 +- python/tvm/runtime/executable.py | 4 -- python/tvm/runtime/module.py | 22 ----------- python/tvm/tir/build.py | 3 +- python/tvm/tir/pipeline.py | 1 - python/tvm/tir/transform/transform.py | 11 ------ src/meta_schedule/arg_info.cc | 10 ++--- src/runtime/cuda/cuda_module.cc | 2 +- src/runtime/library_module.cc | 6 +-- src/runtime/metal/metal_module.mm | 2 +- src/runtime/opencl/opencl_module.cc | 2 +- src/runtime/rocm/rocm_module.cc | 2 +- src/runtime/vulkan/vulkan_wrapped_func.cc | 2 +- src/target/llvm/codegen_cpu.cc | 4 +- src/target/llvm/codegen_hexagon.cc | 11 ------ src/target/llvm/llvm_module.cc | 22 ++--------- src/target/source/codegen_c_host.cc | 14 ------- src/tir/ir/transform.cc | 1 - src/tir/transforms/primfunc_utils.cc | 43 +--------------------- tests/python/codegen/test_target_codegen_device.py | 2 +- .../test_hexagon/test_async_dma_pipeline.py | 10 +++-- .../contrib/test_hexagon/test_parallel_hvx.py | 2 +- .../test_hexagon/test_parallel_hvx_load_vtcm.py | 8 +--- .../contrib/test_hexagon/test_parallel_scalar.py | 4 +- .../contrib/test_hexagon/test_vtcm_bandwidth.py | 8 +++- .../test_runtime_builtin_kv_cache_transfer.py | 2 +- ...runtime_builtin_paged_attention_kv_cache_cpu.py | 2 +- ..._builtin_paged_attention_kv_cache_flashinfer.py | 2 +- ...ltin_paged_attention_kv_cache_mla_flashinfer.py | 2 +- 
...ime_builtin_paged_attention_kv_cache_mla_tir.py | 2 +- ...runtime_builtin_paged_attention_kv_cache_tir.py | 2 +- .../python/relax/test_runtime_builtin_rnn_state.py | 2 +- .../tir-transform/test_tir_transform_helpers.py | 31 ---------------- tests/python/tvmscript/test_tvmscript_roundtrip.py | 2 - 37 files changed, 44 insertions(+), 219 deletions(-) diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h index efbaa6508a..05b57de39d 100644 --- a/include/tvm/runtime/module.h +++ b/include/tvm/runtime/module.h @@ -296,8 +296,6 @@ constexpr const char* tvm_set_device = "__tvm_set_device"; constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state"; /*! \brief Prepare the global barrier before kernels that uses global barrier. */ constexpr const char* tvm_prepare_global_barrier = "__tvm_prepare_global_barrier"; -/*! \brief Placeholder for the module's entry function. */ -constexpr const char* tvm_module_main = "__tvm_main__"; } // namespace symbol // implementations of inline functions. diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h index 6ea50e9ae0..ff9a6ff927 100644 --- a/include/tvm/tir/function.h +++ b/include/tvm/tir/function.h @@ -316,16 +316,6 @@ constexpr const char* kKernelLaunchParams = "tir.kernel_launch_params"; */ constexpr const char* kNoAlias = "tir.noalias"; -/*! - * \brief Mark the function as the entry function of - * the final generated runtime module. - * - * Type: Integer - * - * \note There can only be one entry function per module. - */ -constexpr const char* kIsEntryFunc = "tir.is_entry_func"; - /*! * \brief Mark the function as the global function called from the host. * diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index eb64d87f95..c7af05e7f2 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -703,12 +703,6 @@ TVM_DLL Pass RenormalizeSplitPattern(); */ TVM_DLL Pass BindTarget(Target target); -/*! 
- * \brief Set a PrimFunc as the entry point if it is only function in IRModule. - * \return The pass. - */ -TVM_DLL Pass AnnotateEntryFunc(); - /*! * \brief Filter PrimFuncs with a given condition. * \return The pass. diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java b/jvm/core/src/main/java/org/apache/tvm/Module.java index 5e78e26ae7..9fa65054f9 100644 --- a/jvm/core/src/main/java/org/apache/tvm/Module.java +++ b/jvm/core/src/main/java/org/apache/tvm/Module.java @@ -46,7 +46,7 @@ public class Module extends TVMObject { } private Function entry = null; - private final String entryName = "__tvm_main__"; + private final String entryName = "__tvm_ffi_main__"; /** diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py index b6e13a65a9..a1a6606765 100644 --- a/python/tvm/runtime/executable.py +++ b/python/tvm/runtime/executable.py @@ -36,10 +36,6 @@ class Executable: """Get the PackedFunc from the jitted module.""" return self.jit().get_function(name, query_imports=True) - def __call__(self, *args, **kwargs) -> Any: - """Call the executable.""" - return self.jit().entry_func(*args, **kwargs) - def jit( self, *, diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 3dd4de5da0..30f83474dc 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -103,24 +103,8 @@ class Module(tvm.ffi.Object): def __new__(cls): instance = super(Module, cls).__new__(cls) # pylint: disable=no-value-for-parameter - instance.entry_name = "__tvm_main__" - instance._entry = None return instance - @property - def entry_func(self): - """Get the entry function - - Returns - ------- - f : tvm.runtime.PackedFunc - The entry function if exist - """ - if self._entry: - return self._entry - self._entry = self.get_function("__tvm_main__") - return self._entry - def implements_function(self, name, query_imports=False): """Returns True if the module has a definition for the global function with name. 
Note that has_function(name) does not imply get_function(name) is non-null since the module @@ -179,12 +163,6 @@ class Module(tvm.ffi.Object): raise ValueError("Can only take string as function name") return self.get_function(name) - def __call__(self, *args): - if self._entry: - return self._entry(*args) - # pylint: disable=not-callable - return self.entry_func(*args) - @property def type_key(self): """Get type key of the module.""" diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py index 98e549cc9c..431c601e72 100644 --- a/python/tvm/tir/build.py +++ b/python/tvm/tir/build.py @@ -80,8 +80,7 @@ def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, Dict[Target, IRModu @T.prim_func def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.handle): T.func_attr({"target": T.target({"keys": ["cpu"], "kind": "c"}), - "calling_conv": 1, # kCPackedFunc for entry functions - "tir.is_entry_func": True}) + "calling_conv": 1}) # ... main function implementation The function will return: diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index ae78b05738..1082cd8fac 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -89,7 +89,6 @@ def default_tir_pipeline(): tir.transform.VerifyVTCMLimit(), tir.transform.LowerVtcmAlloc(), tir.transform.VerifyMemory(), - tir.transform.AnnotateEntryFunc(), ] ) if bool(config.get("tir.detect_global_barrier", False)): diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 93a182ca3b..178a203ca5 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -1018,17 +1018,6 @@ def BindTarget(target): return _ffi_api.BindTarget(target) # type: ignore -def AnnotateEntryFunc(): - """Set a PrimFunc as the entry point if it is only function in IRModule. 
- - Returns - ------- - fpass : tvm.transform.Pass - The result pass - """ - return _ffi_api.AnnotateEntryFunc() # type: ignore - - def Filter(fcond: Callable): """Filter out PrimFuncs that does not satisfy the given condition. `fcond` should be a function that takes a primfunc and returns boolean. diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc index 9c2ba084ad..c46a0bf280 100644 --- a/src/meta_schedule/arg_info.cc +++ b/src/meta_schedule/arg_info.cc @@ -25,12 +25,12 @@ namespace meta_schedule { /*! * \brief Find the entry function of the given IRModule, i.e, functions marked by - * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc. + * whose name is `main` or being the only PrimeFunc. * \param mod The IRModule to find the entry function. * \return The entry function. */ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { - // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc` + // Priority 1: PrimFunc marked as `main` int num_prim_func = 0; const tir::PrimFuncNode* main_func = nullptr; const tir::PrimFuncNode* last_func = nullptr; @@ -39,9 +39,6 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { BaseFunc base_func = kv.second; if (const auto* func = base_func.as<tir::PrimFuncNode>()) { last_func = func; - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - return GetRef<tir::PrimFunc>(func); - } if (gv->name_hint == "main") { main_func = func; } @@ -57,8 +54,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) { LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: " << mod; } if (num_prim_func > 1) { - LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are " - "annotated with `kIsEntryFunc`, i.e. 
`tir.is_entry_func`" + LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are main" << mod; } return GetRef<tir::PrimFunc>(last_func); diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index 2435cccf0a..8d1ed2b2d1 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -258,7 +258,7 @@ class CUDAPrepGlobalBarrier { ffi::Function CUDAModuleNode::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; + ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main"; if (name == symbol::tvm_prepare_global_barrier) { return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self)); } diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc index fffac4adea..77a1072c33 100644 --- a/src/runtime/library_module.cc +++ b/src/runtime/library_module.cc @@ -50,11 +50,11 @@ class LibraryModuleNode final : public ModuleNode { ffi::Function GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) final { TVMFFISafeCallType faddr; - if (name == runtime::symbol::tvm_module_main) { + if (name == runtime::symbol::tvm_ffi_main) { const char* entry_name = - reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main)); + reinterpret_cast<const char*>(lib_->GetSymbol(runtime::symbol::tvm_ffi_main)); ICHECK(entry_name != nullptr) - << "Symbol " << runtime::symbol::tvm_module_main << " is not presented"; + << "Symbol " << runtime::symbol::tvm_ffi_main << " is not presented"; faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(entry_name)); } else { faddr = reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str())); diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index be36e6197f..a054378e8f 100644 --- a/src/runtime/metal/metal_module.mm +++ 
b/src/runtime/metal/metal_module.mm @@ -264,7 +264,7 @@ ffi::Function MetalModuleNode::GetFunction(const String& name, ffi::Function ret; AUTORELEASEPOOL { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; + ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) { ret = ffi::Function(); diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index 19b426d4b4..57621b8609 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -138,7 +138,7 @@ cl::OpenCLWorkspace* OpenCLModuleNodeBase::GetGlobalWorkspace() { ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; + ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 791e4b1569..8a71bcbbf7 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -195,7 +195,7 @@ class ROCMWrappedFunc { ffi::Function ROCMModuleNode::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; + ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index f4922a1bf0..d863b02cdf 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ 
b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -208,7 +208,7 @@ VulkanModuleNode::~VulkanModuleNode() { ffi::Function VulkanModuleNode::GetFunction(const String& name, const ObjectPtr<Object>& sptr_to_self) { ICHECK_EQ(sptr_to_self.get(), this); - ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have main"; + ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main"; auto it = fmap_.find(name); if (it == fmap_.end()) return ffi::Function(); const FunctionInfo& info = it->second; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 4dd24026c0..a04273094c 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -234,7 +234,7 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 1); llvm::GlobalVariable* global = new llvm::GlobalVariable(*module_, type, true, llvm::GlobalValue::WeakAnyLinkage, nullptr, - runtime::symbol::tvm_module_main); + runtime::symbol::tvm_ffi_main); #if TVM_LLVM_VERSION >= 100 global->setAlignment(llvm::Align(1)); #else @@ -243,7 +243,7 @@ void CodeGenCPU::AddMainFunction(const std::string& entry_func_name) { // comdat is needed for windows select any linking to work // set comdat to Any(weak linking) if (llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) { - llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_module_main); + llvm::Comdat* comdat = module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main); comdat->setSelectionKind(llvm::Comdat::Any); global->setComdat(comdat); } diff --git a/src/target/llvm/codegen_hexagon.cc b/src/target/llvm/codegen_hexagon.cc index 6f90da3d8a..3d8ed08eee 100644 --- a/src/target/llvm/codegen_hexagon.cc +++ b/src/target/llvm/codegen_hexagon.cc @@ -483,9 +483,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { (void)CallOnce; auto cg = std::make_unique<CodeGenHexagon>(); 
- - std::string entry_func; - for (auto kv : mod->functions) { if (!kv.second->IsInstance<PrimFuncNode>()) { // (@jroesch): we relax constraints here, relax functions will just be ignored. @@ -493,18 +490,10 @@ runtime::Module BuildHexagon(IRModule mod, Target target) { continue; } auto f = Downcast<PrimFunc>(kv.second); - if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); - ICHECK(global_symbol.has_value()); - entry_func = global_symbol.value(); - } } cg->Init("TVMHexagonModule", llvm_target.get(), std::nullopt, false, false); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); - if (entry_func.length() != 0) { - cg->AddMainFunction(entry_func); - } // Uncomment to get the LLVM module right out of codegen, before optimizations. // std::cerr << "HexagonModule.0 {\n" << *cg->GetModulePtr() << "}\n"; diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index a9e09652ee..45ede6efef 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -190,15 +190,7 @@ ffi::Function LLVMModuleNode::GetFunction(const String& name, TVMFFISafeCallType faddr; With<LLVMTarget> llvm_target(*llvm_instance_, LLVMTarget::GetTargetMetadata(*module_)); - if (name == runtime::symbol::tvm_module_main) { - const char* entry_name = reinterpret_cast<const char*>( - GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target)); - ICHECK(entry_name != nullptr) << "Symbol " << runtime::symbol::tvm_module_main - << " is not presented"; - faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(entry_name, *llvm_target)); - } else { - faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target)); - } + faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, *llvm_target)); if (faddr == nullptr) return ffi::Function(); return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self); } @@ -337,15 +329,9 @@ void LLVMModuleNode::Init(const 
IRModule& mod, const Target& target) { } auto f = Downcast<PrimFunc>(kv.second); auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); - bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc); - - ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed externally."; if (global_symbol) { function_names_.push_back(global_symbol.value()); - if (is_entry_func) { - entry_func = global_symbol.value(); - } } } // TODO(@jroesch): follow up on this condition. @@ -355,11 +341,9 @@ void LLVMModuleNode::Init(const IRModule& mod, const Target& target) { cg->Init("TVMMod", llvm_target.get(), system_lib_prefix, system_lib_prefix.has_value(), false); cg->SetFastMathFlags(llvm_target->GetFastMathFlags()); cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end()); - if (entry_func.length() != 0) { - cg->AddMainFunction(entry_func); - } + q - module_owning_ptr_ = cg->Finish(); + module_owning_ptr_ = cg->Finish(); module_ = module_owning_ptr_.get(); jit_engine_ = llvm_target->GetJITEngine(); llvm_target->SetTargetMetadata(module_); diff --git a/src/target/source/codegen_c_host.cc b/src/target/source/codegen_c_host.cc index 6cd12a9319..1c8a3dd2ea 100644 --- a/src/target/source/codegen_c_host.cc +++ b/src/target/source/codegen_c_host.cc @@ -72,20 +72,6 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const PrimFunc& func, emit_fwd_func_decl_ = emit_fwd_func_decl; CodeGenC::AddFunction(gvar, func); - if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - ICHECK(global_symbol.has_value()) - << "CodeGenCHost: The entry func must have the global_symbol attribute, " - << "but function " << gvar << " only has attributes " << func->attrs; - - function_names_.push_back(runtime::symbol::tvm_module_main); - stream << "// CodegenC: NOTE: Auto-generated entry function\n"; - PrintFuncPrefix(stream); - PrintType(func->ret_type, stream); - stream << " " << tvm::runtime::symbol::tvm_module_main - << "(void* self, void* args,int num_args, 
void* result) {\n"; - stream << " return " << global_symbol.value() << "(self, args, num_args, result);\n"; - stream << "}\n"; - } } void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol, diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc index aafe6277e2..6ef7ffdca9 100644 --- a/src/tir/ir/transform.cc +++ b/src/tir/ir/transform.cc @@ -42,7 +42,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool); -TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>); TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); diff --git a/src/tir/transforms/primfunc_utils.cc b/src/tir/transforms/primfunc_utils.cc index b1f3476eab..99cf901377 100644 --- a/src/tir/transforms/primfunc_utils.cc +++ b/src/tir/transforms/primfunc_utils.cc @@ -29,45 +29,6 @@ namespace tvm { namespace tir { namespace transform { -transform::Pass AnnotateEntryFunc() { - auto fpass = [](IRModule mod, transform::PassContext ctx) -> IRModule { - // If only a single function exists, that function must be the entry - if (mod->functions.size() == 1) { - auto [gvar, base_func] = *mod->functions.begin(); - if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { - if (auto ptr = base_func.as<PrimFuncNode>()) { - mod->Update(gvar, WithAttr(GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, true)); - } - } - return mod; - } - - // If the module has multiple functions, but only one is exposed - // externally, that function must be the entry. 
- bool has_external_non_primfuncs = false; - IRModule with_annotations; - for (const auto& [gvar, base_func] : mod->functions) { - bool is_external = base_func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value(); - if (is_external) { - if (auto ptr = base_func.as<PrimFuncNode>()) { - with_annotations->Add(gvar, - WithAttr(GetRef<PrimFunc>(ptr), tir::attr::kIsEntryFunc, true)); - } else { - has_external_non_primfuncs = true; - } - } - } - if (with_annotations->functions.size() == 1 && !has_external_non_primfuncs) { - mod->Update(with_annotations); - return mod; - } - - // Default fallback, no annotations may be inferred. - return mod; - }; - return tvm::transform::CreateModulePass(fpass, 0, "tir.AnnotateEntryFunc", {}); -} - transform::Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond) { auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext ctx) { if (fcond(f)) { @@ -81,9 +42,7 @@ transform::Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond) { TVM_FFI_STATIC_INIT_BLOCK({ namespace refl = tvm::ffi::reflection; - refl::GlobalDef() - .def("tir.transform.AnnotateEntryFunc", AnnotateEntryFunc) - .def("tir.transform.Filter", Filter); + refl::GlobalDef().def("tir.transform.Filter", Filter); }); } // namespace transform diff --git a/tests/python/codegen/test_target_codegen_device.py b/tests/python/codegen/test_target_codegen_device.py index 4dad03d700..0089e0bea6 100644 --- a/tests/python/codegen/test_target_codegen_device.py +++ b/tests/python/codegen/test_target_codegen_device.py @@ -95,7 +95,7 @@ def test_add_pipeline(): dev = tvm.device(device, 0) target = tvm.target.Target(device, host) mhost = tvm.tir.build(sch.mod, target=target) - f = mhost.entry_func + f = mhost["main"] # launch the kernel. 
n = 1027 a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev) diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py index fe7a615531..808c016cf8 100644 --- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py +++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -""" Test different strategies for loading data into vtcm before running HVX workloads. """ +"""Test different strategies for loading data into vtcm before running HVX workloads.""" import numpy as np import pytest @@ -289,9 +289,13 @@ def evaluate( if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1) + timer = module.time_evaluator( + "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 + ) else: - timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10) + timer = module.time_evaluator( + "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 + ) time = timer(a_hexagon, b_hexagon, c_hexagon) if expected_output is not None: diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py b/tests/python/contrib/test_hexagon/test_parallel_hvx.py index 8f77fa1c40..6822352568 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py @@ -160,7 +160,7 @@ def evaluate(hexagon_session, shape_dtypes, expected_output_producer, sch): repeat = 1 timer = module.time_evaluator( - "__tvm_main__", hexagon_session.device, number=number, repeat=repeat + "__tvm_ffi_main__", hexagon_session.device, number=number, repeat=repeat ) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected_output_producer(c_shape, a, b)) diff --git 
a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py index a584997dd5..63a65f3716 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py @@ -330,9 +330,7 @@ def setup_and_run(hexagon_session, sch, a, b, c, operations, mem_scope="global") number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) time = timer(a_hexagon, b_hexagon, c_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) return gops, c_hexagon.numpy() @@ -364,9 +362,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, c, operations): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, b_vtcm_hexagon, c_vtcm_hexagon) gops = round(operations * 128 * 3 / time.mean / 1e9, 4) return gops, c_hexagon.numpy() diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py b/tests/python/contrib/test_hexagon/test_parallel_scalar.py index bd9c78d5da..5c8043fdff 100644 --- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py +++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py @@ -104,9 +104,7 @@ def evaluate(hexagon_session, operations, expected, sch): number = 1 repeat = 1 - timer = module.time_evaluator( - "__tvm_main__", hexagon_session.device, number=number, repeat=repeat - ) + timer = module.time_evaluator("main", hexagon_session.device, number=number, repeat=repeat) runtime = timer(a_hexagon, b_hexagon, c_hexagon) tvm.testing.assert_allclose(c_hexagon.numpy(), expected(a, b)) diff --git 
a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py index 931f99b2ec..265f2bf5fd 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py +++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py @@ -108,9 +108,13 @@ def evaluate(hexagon_session, sch, size): if tvm.testing.utils.IS_IN_CI: # Run with reduced number and repeat for CI - timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=1, repeat=1) + timer = module.time_evaluator( + "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1 + ) else: - timer = module.time_evaluator("__tvm_main__", hexagon_session.device, number=10, repeat=10) + timer = module.time_evaluator( + "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10 + ) runtime = timer(a_hexagon, a_vtcm_hexagon) diff --git a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py index 81acf5ee86..0d2f445cb8 100644 --- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py +++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py @@ -170,7 +170,7 @@ def set_global_func(head_dim, dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f["main"]) ( ftranspose_append, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py index 1941edeaa7..305fd18f35 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py @@ -140,7 +140,7 @@ def set_global_func(head_dim, dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) 
+ builts.append(f["main"]) ( ftranspose_append, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py index ffd3452292..e13ce1ca7b 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py @@ -156,7 +156,7 @@ def set_global_func(): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f["main"]) ( ftranspose_append, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py index 2f726064a7..53044a786c 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py @@ -169,7 +169,7 @@ def set_global_func(dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f["main"]) ( ftranspose_append, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py index b2982abdb0..73a4d89dad 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py @@ -134,7 +134,7 @@ def set_global_func(dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f["main"]) ( ftranspose_append, diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 8cd3a73740..44169828e2 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -142,7 +142,7 @@ def set_global_func(head_dim, dtype): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) f = tvm.tir.build(mod["main"], target=target) - builts.append(f.entry_func) + builts.append(f["main"]) ( ftranspose_append, diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py index 095aba8b83..fe8c19257d 100644 --- a/tests/python/relax/test_runtime_builtin_rnn_state.py +++ b/tests/python/relax/test_runtime_builtin_rnn_state.py @@ -81,7 +81,7 @@ def set_global_func(): with target: mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod) # pylint: disable=not-callable f = tvm.tir.build(mod["main"], target=target) - return f.entry_func + return f["main"] _f_tir_gets, _f_tir_sets = [], [] for state in states: diff --git a/tests/python/tir-transform/test_tir_transform_helpers.py b/tests/python/tir-transform/test_tir_transform_helpers.py index 0bbd0e7160..a7ac11d7a9 100644 --- a/tests/python/tir-transform/test_tir_transform_helpers.py +++ b/tests/python/tir-transform/test_tir_transform_helpers.py @@ -21,27 +21,6 @@ from tvm.script import tir as T, ir as I import tvm.testing -def test_annotate_entry_func_single_primfunc(): - @tvm.script.ir_module - class MockModule: - @T.prim_func(private=True) - def func1(A: T.Buffer((16,), "float32")): - for i in T.serial(16): - if i == 5: - if i == 5: - A[i] = 0.0 - - mod = MockModule - assert mod - assert not mod["func1"].attrs - after = tvm.tir.transform.AnnotateEntryFunc()(mod) - assert ( - after["func1"].attrs - and "tir.is_entry_func" in after["func1"].attrs - and after["func1"].attrs["tir.is_entry_func"] - ) - - # Test module @tvm.script.ir_module class MockModule: @@ -60,16 +39,6 
@@ class MockModule: A[i] = 0.0 -@pytest.mark.xfail -def test_annotate_entry_func_multiple_primfunc(): - mod = MockModule - assert mod - assert not mod["func1"].attrs - assert not mod["func2"].attrs - # This should fail - after = tvm.tir.transform.AnnotateEntryFunc()(mod) - - def test_bind_target(): mod = MockModule assert mod diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py index 0e1b328844..73ca1dad3b 100644 --- a/tests/python/tvmscript/test_tvmscript_roundtrip.py +++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py @@ -196,7 +196,6 @@ def opt_gemm_mod_host(): T.func_attr( { "tir.noalias": True, - "tir.is_entry_func": True, "calling_conv": 1, } ) @@ -2242,7 +2241,6 @@ def opt_conv_tensorcore_mod_host(): { "tir.noalias": True, "global_symbol": "default_function", - "tir.is_entry_func": True, "calling_conv": 1, } )