This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main-mod
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 4f5f2d9571f2ec8dde7ee5570752c20cafbcc256
Author: tqchen <tianqi.tc...@gmail.com>
AuthorDate: Tue Aug 12 10:22:42 2025 -0400

    [REFACTOR] Phase out entry_func
    
    This PR phases out the entry function. Previously, the module itself had a 
default entry function.
    Supporting such a feature requires extra duplicated code and indirection.
    
    Most of our use cases can move toward explicitly naming the function as 
mod["main"],
    so this PR phases out this behavior.
---
 include/tvm/runtime/module.h                       |  2 -
 include/tvm/tir/function.h                         | 10 -----
 include/tvm/tir/transform.h                        |  6 ---
 jvm/core/src/main/java/org/apache/tvm/Module.java  |  2 +-
 python/tvm/runtime/executable.py                   |  4 --
 python/tvm/runtime/module.py                       | 22 -----------
 python/tvm/tir/build.py                            |  3 +-
 python/tvm/tir/pipeline.py                         |  1 -
 python/tvm/tir/transform/transform.py              | 11 ------
 src/meta_schedule/arg_info.cc                      | 10 ++---
 src/runtime/cuda/cuda_module.cc                    |  2 +-
 src/runtime/library_module.cc                      |  6 +--
 src/runtime/metal/metal_module.mm                  |  2 +-
 src/runtime/opencl/opencl_module.cc                |  2 +-
 src/runtime/rocm/rocm_module.cc                    |  2 +-
 src/runtime/vulkan/vulkan_wrapped_func.cc          |  2 +-
 src/target/llvm/codegen_cpu.cc                     |  4 +-
 src/target/llvm/codegen_hexagon.cc                 | 11 ------
 src/target/llvm/llvm_module.cc                     | 22 ++---------
 src/target/source/codegen_c_host.cc                | 14 -------
 src/tir/ir/transform.cc                            |  1 -
 src/tir/transforms/primfunc_utils.cc               | 43 +---------------------
 tests/python/codegen/test_target_codegen_device.py |  2 +-
 .../test_hexagon/test_async_dma_pipeline.py        | 10 +++--
 .../contrib/test_hexagon/test_parallel_hvx.py      |  2 +-
 .../test_hexagon/test_parallel_hvx_load_vtcm.py    |  8 +---
 .../contrib/test_hexagon/test_parallel_scalar.py   |  4 +-
 .../contrib/test_hexagon/test_vtcm_bandwidth.py    |  8 +++-
 .../test_runtime_builtin_kv_cache_transfer.py      |  2 +-
 ...runtime_builtin_paged_attention_kv_cache_cpu.py |  2 +-
 ..._builtin_paged_attention_kv_cache_flashinfer.py |  2 +-
 ...ltin_paged_attention_kv_cache_mla_flashinfer.py |  2 +-
 ...ime_builtin_paged_attention_kv_cache_mla_tir.py |  2 +-
 ...runtime_builtin_paged_attention_kv_cache_tir.py |  2 +-
 .../python/relax/test_runtime_builtin_rnn_state.py |  2 +-
 .../tir-transform/test_tir_transform_helpers.py    | 31 ----------------
 tests/python/tvmscript/test_tvmscript_roundtrip.py |  2 -
 37 files changed, 44 insertions(+), 219 deletions(-)

diff --git a/include/tvm/runtime/module.h b/include/tvm/runtime/module.h
index efbaa6508a..05b57de39d 100644
--- a/include/tvm/runtime/module.h
+++ b/include/tvm/runtime/module.h
@@ -296,8 +296,6 @@ constexpr const char* tvm_set_device = "__tvm_set_device";
 constexpr const char* tvm_global_barrier_state = "__tvm_global_barrier_state";
 /*! \brief Prepare the global barrier before kernels that uses global barrier. 
*/
 constexpr const char* tvm_prepare_global_barrier = 
"__tvm_prepare_global_barrier";
-/*! \brief Placeholder for the module's entry function. */
-constexpr const char* tvm_module_main = "__tvm_main__";
 }  // namespace symbol
 
 // implementations of inline functions.
diff --git a/include/tvm/tir/function.h b/include/tvm/tir/function.h
index 6ea50e9ae0..ff9a6ff927 100644
--- a/include/tvm/tir/function.h
+++ b/include/tvm/tir/function.h
@@ -316,16 +316,6 @@ constexpr const char* kKernelLaunchParams = 
"tir.kernel_launch_params";
  */
 constexpr const char* kNoAlias = "tir.noalias";
 
-/*!
- * \brief Mark the function as the entry function of
- *        the final generated runtime module.
- *
- * Type: Integer
- *
- * \note There can only be one entry function per module.
- */
-constexpr const char* kIsEntryFunc = "tir.is_entry_func";
-
 /*!
  * \brief Mark the function as the global function called from the host.
  *
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index eb64d87f95..c7af05e7f2 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -703,12 +703,6 @@ TVM_DLL Pass RenormalizeSplitPattern();
  */
 TVM_DLL Pass BindTarget(Target target);
 
-/*!
- * \brief Set a PrimFunc as the entry point if it is only function in IRModule.
- * \return The pass.
- */
-TVM_DLL Pass AnnotateEntryFunc();
-
 /*!
  * \brief Filter PrimFuncs with a given condition.
  * \return The pass.
diff --git a/jvm/core/src/main/java/org/apache/tvm/Module.java 
b/jvm/core/src/main/java/org/apache/tvm/Module.java
index 5e78e26ae7..9fa65054f9 100644
--- a/jvm/core/src/main/java/org/apache/tvm/Module.java
+++ b/jvm/core/src/main/java/org/apache/tvm/Module.java
@@ -46,7 +46,7 @@ public class Module extends TVMObject {
   }
 
   private Function entry = null;
-  private final String entryName = "__tvm_main__";
+  private final String entryName = "__tvm_ffi_main__";
 
 
   /**
diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py
index b6e13a65a9..a1a6606765 100644
--- a/python/tvm/runtime/executable.py
+++ b/python/tvm/runtime/executable.py
@@ -36,10 +36,6 @@ class Executable:
         """Get the PackedFunc from the jitted module."""
         return self.jit().get_function(name, query_imports=True)
 
-    def __call__(self, *args, **kwargs) -> Any:
-        """Call the executable."""
-        return self.jit().entry_func(*args, **kwargs)
-
     def jit(
         self,
         *,
diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py
index 3dd4de5da0..30f83474dc 100644
--- a/python/tvm/runtime/module.py
+++ b/python/tvm/runtime/module.py
@@ -103,24 +103,8 @@ class Module(tvm.ffi.Object):
 
     def __new__(cls):
         instance = super(Module, cls).__new__(cls)  # pylint: 
disable=no-value-for-parameter
-        instance.entry_name = "__tvm_main__"
-        instance._entry = None
         return instance
 
-    @property
-    def entry_func(self):
-        """Get the entry function
-
-        Returns
-        -------
-        f : tvm.runtime.PackedFunc
-            The entry function if exist
-        """
-        if self._entry:
-            return self._entry
-        self._entry = self.get_function("__tvm_main__")
-        return self._entry
-
     def implements_function(self, name, query_imports=False):
         """Returns True if the module has a definition for the global function 
with name. Note
         that has_function(name) does not imply get_function(name) is non-null 
since the module
@@ -179,12 +163,6 @@ class Module(tvm.ffi.Object):
             raise ValueError("Can only take string as function name")
         return self.get_function(name)
 
-    def __call__(self, *args):
-        if self._entry:
-            return self._entry(*args)
-        # pylint: disable=not-callable
-        return self.entry_func(*args)
-
     @property
     def type_key(self):
         """Get type key of the module."""
diff --git a/python/tvm/tir/build.py b/python/tvm/tir/build.py
index 98e549cc9c..431c601e72 100644
--- a/python/tvm/tir/build.py
+++ b/python/tvm/tir/build.py
@@ -80,8 +80,7 @@ def split_host_device_mods(mod: IRModule) -> Tuple[IRModule, 
Dict[Target, IRModu
             @T.prim_func
             def main(self_handle: T.handle, args: T.handle, num_args: T.int32, 
result: T.handle):
                 T.func_attr({"target": T.target({"keys": ["cpu"], "kind": 
"c"}),
-                            "calling_conv": 1,  # kCPackedFunc for entry 
functions
-                            "tir.is_entry_func": True})
+                            "calling_conv": 1})
                 # ... main function implementation
 
     The function will return:
diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py
index ae78b05738..1082cd8fac 100644
--- a/python/tvm/tir/pipeline.py
+++ b/python/tvm/tir/pipeline.py
@@ -89,7 +89,6 @@ def default_tir_pipeline():
                 tir.transform.VerifyVTCMLimit(),
                 tir.transform.LowerVtcmAlloc(),
                 tir.transform.VerifyMemory(),
-                tir.transform.AnnotateEntryFunc(),
             ]
         )
         if bool(config.get("tir.detect_global_barrier", False)):
diff --git a/python/tvm/tir/transform/transform.py 
b/python/tvm/tir/transform/transform.py
index 93a182ca3b..178a203ca5 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -1018,17 +1018,6 @@ def BindTarget(target):
     return _ffi_api.BindTarget(target)  # type: ignore
 
 
-def AnnotateEntryFunc():
-    """Set a PrimFunc as the entry point if it is only function in IRModule.
-
-    Returns
-    -------
-    fpass : tvm.transform.Pass
-        The result pass
-    """
-    return _ffi_api.AnnotateEntryFunc()  # type: ignore
-
-
 def Filter(fcond: Callable):
     """Filter out PrimFuncs that does not satisfy the given condition.
     `fcond` should be a function that takes a primfunc and returns boolean.
diff --git a/src/meta_schedule/arg_info.cc b/src/meta_schedule/arg_info.cc
index 9c2ba084ad..c46a0bf280 100644
--- a/src/meta_schedule/arg_info.cc
+++ b/src/meta_schedule/arg_info.cc
@@ -25,12 +25,12 @@ namespace meta_schedule {
 
 /*!
  * \brief Find the entry function of the given IRModule, i.e, functions marked 
by
- * `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
+ * whose name is `main` or being the only PrimeFunc.
  * \param mod The IRModule to find the entry function.
  * \return The entry function.
  */
 inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
-  // Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
+  // Priority 1: PrimFunc marked as `main`
   int num_prim_func = 0;
   const tir::PrimFuncNode* main_func = nullptr;
   const tir::PrimFuncNode* last_func = nullptr;
@@ -39,9 +39,6 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
     BaseFunc base_func = kv.second;
     if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
       last_func = func;
-      if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-        return GetRef<tir::PrimFunc>(func);
-      }
       if (gv->name_hint == "main") {
         main_func = func;
       }
@@ -57,8 +54,7 @@ inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
     LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: 
" << mod;
   }
   if (num_prim_func > 1) {
-    LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but 
none of them are "
-                  "annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
+    LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but 
none of them are main"
                << mod;
   }
   return GetRef<tir::PrimFunc>(last_func);
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 2435cccf0a..8d1ed2b2d1 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -258,7 +258,7 @@ class CUDAPrepGlobalBarrier {
 ffi::Function CUDAModuleNode::GetFunction(const String& name,
                                           const ObjectPtr<Object>& 
sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
+  ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main";
   if (name == symbol::tvm_prepare_global_barrier) {
     return ffi::Function(CUDAPrepGlobalBarrier(this, sptr_to_self));
   }
diff --git a/src/runtime/library_module.cc b/src/runtime/library_module.cc
index fffac4adea..77a1072c33 100644
--- a/src/runtime/library_module.cc
+++ b/src/runtime/library_module.cc
@@ -50,11 +50,11 @@ class LibraryModuleNode final : public ModuleNode {
 
   ffi::Function GetFunction(const String& name, const ObjectPtr<Object>& 
sptr_to_self) final {
     TVMFFISafeCallType faddr;
-    if (name == runtime::symbol::tvm_module_main) {
+    if (name == runtime::symbol::tvm_ffi_main) {
       const char* entry_name =
-          reinterpret_cast<const 
char*>(lib_->GetSymbol(runtime::symbol::tvm_module_main));
+          reinterpret_cast<const 
char*>(lib_->GetSymbol(runtime::symbol::tvm_ffi_main));
       ICHECK(entry_name != nullptr)
-          << "Symbol " << runtime::symbol::tvm_module_main << " is not 
presented";
+          << "Symbol " << runtime::symbol::tvm_ffi_main << " is not presented";
       faddr = 
reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(entry_name));
     } else {
       faddr = 
reinterpret_cast<TVMFFISafeCallType>(lib_->GetSymbol(name.c_str()));
diff --git a/src/runtime/metal/metal_module.mm 
b/src/runtime/metal/metal_module.mm
index be36e6197f..a054378e8f 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -264,7 +264,7 @@ ffi::Function MetalModuleNode::GetFunction(const String& 
name,
   ffi::Function ret;
   AUTORELEASEPOOL {
     ICHECK_EQ(sptr_to_self.get(), this);
-    ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
+    ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have 
main";
     auto it = fmap_.find(name);
     if (it == fmap_.end()) {
       ret = ffi::Function();
diff --git a/src/runtime/opencl/opencl_module.cc 
b/src/runtime/opencl/opencl_module.cc
index 19b426d4b4..57621b8609 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -138,7 +138,7 @@ cl::OpenCLWorkspace* 
OpenCLModuleNodeBase::GetGlobalWorkspace() {
 ffi::Function OpenCLModuleNodeBase::GetFunction(const String& name,
                                                 const ObjectPtr<Object>& 
sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
+  ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index 791e4b1569..8a71bcbbf7 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -195,7 +195,7 @@ class ROCMWrappedFunc {
 ffi::Function ROCMModuleNode::GetFunction(const String& name,
                                           const ObjectPtr<Object>& 
sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
+  ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc 
b/src/runtime/vulkan/vulkan_wrapped_func.cc
index f4922a1bf0..d863b02cdf 100644
--- a/src/runtime/vulkan/vulkan_wrapped_func.cc
+++ b/src/runtime/vulkan/vulkan_wrapped_func.cc
@@ -208,7 +208,7 @@ VulkanModuleNode::~VulkanModuleNode() {
 ffi::Function VulkanModuleNode::GetFunction(const String& name,
                                             const ObjectPtr<Object>& 
sptr_to_self) {
   ICHECK_EQ(sptr_to_self.get(), this);
-  ICHECK_NE(name, symbol::tvm_module_main) << "Device function do not have 
main";
+  ICHECK_NE(name, symbol::tvm_ffi_main) << "Device function do not have main";
   auto it = fmap_.find(name);
   if (it == fmap_.end()) return ffi::Function();
   const FunctionInfo& info = it->second;
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 4dd24026c0..a04273094c 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -234,7 +234,7 @@ void CodeGenCPU::AddMainFunction(const std::string& 
entry_func_name) {
   llvm::Type* type = llvm::ArrayType::get(t_char_, entry_func_name.length() + 
1);
   llvm::GlobalVariable* global =
       new llvm::GlobalVariable(*module_, type, true, 
llvm::GlobalValue::WeakAnyLinkage, nullptr,
-                               runtime::symbol::tvm_module_main);
+                               runtime::symbol::tvm_ffi_main);
 #if TVM_LLVM_VERSION >= 100
   global->setAlignment(llvm::Align(1));
 #else
@@ -243,7 +243,7 @@ void CodeGenCPU::AddMainFunction(const std::string& 
entry_func_name) {
   // comdat is needed for windows select any linking to work
   // set comdat to Any(weak linking)
   if 
(llvm_target_->GetOrCreateTargetMachine()->getTargetTriple().isOSWindows()) {
-    llvm::Comdat* comdat = 
module_->getOrInsertComdat(runtime::symbol::tvm_module_main);
+    llvm::Comdat* comdat = 
module_->getOrInsertComdat(runtime::symbol::tvm_ffi_main);
     comdat->setSelectionKind(llvm::Comdat::Any);
     global->setComdat(comdat);
   }
diff --git a/src/target/llvm/codegen_hexagon.cc 
b/src/target/llvm/codegen_hexagon.cc
index 6f90da3d8a..3d8ed08eee 100644
--- a/src/target/llvm/codegen_hexagon.cc
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -483,9 +483,6 @@ runtime::Module BuildHexagon(IRModule mod, Target target) {
   (void)CallOnce;
 
   auto cg = std::make_unique<CodeGenHexagon>();
-
-  std::string entry_func;
-
   for (auto kv : mod->functions) {
     if (!kv.second->IsInstance<PrimFuncNode>()) {
       // (@jroesch): we relax constraints here, relax functions will just be 
ignored.
@@ -493,18 +490,10 @@ runtime::Module BuildHexagon(IRModule mod, Target target) 
{
       continue;
     }
     auto f = Downcast<PrimFunc>(kv.second);
-    if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-      auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-      ICHECK(global_symbol.has_value());
-      entry_func = global_symbol.value();
-    }
   }
 
   cg->Init("TVMHexagonModule", llvm_target.get(), std::nullopt, false, false);
   cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
-  if (entry_func.length() != 0) {
-    cg->AddMainFunction(entry_func);
-  }
 
   // Uncomment to get the LLVM module right out of codegen, before 
optimizations.
   // std::cerr << "HexagonModule.0 {\n" << *cg->GetModulePtr() << "}\n";
diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc
index a9e09652ee..45ede6efef 100644
--- a/src/target/llvm/llvm_module.cc
+++ b/src/target/llvm/llvm_module.cc
@@ -190,15 +190,7 @@ ffi::Function LLVMModuleNode::GetFunction(const String& 
name,
 
   TVMFFISafeCallType faddr;
   With<LLVMTarget> llvm_target(*llvm_instance_, 
LLVMTarget::GetTargetMetadata(*module_));
-  if (name == runtime::symbol::tvm_module_main) {
-    const char* entry_name = reinterpret_cast<const char*>(
-        GetGlobalAddr(runtime::symbol::tvm_module_main, *llvm_target));
-    ICHECK(entry_name != nullptr) << "Symbol " << 
runtime::symbol::tvm_module_main
-                                  << " is not presented";
-    faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(entry_name, 
*llvm_target));
-  } else {
-    faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, 
*llvm_target));
-  }
+  faddr = reinterpret_cast<TVMFFISafeCallType>(GetFunctionAddr(name, 
*llvm_target));
   if (faddr == nullptr) return ffi::Function();
   return tvm::runtime::WrapFFIFunction(faddr, sptr_to_self);
 }
@@ -337,15 +329,9 @@ void LLVMModuleNode::Init(const IRModule& mod, const 
Target& target) {
     }
     auto f = Downcast<PrimFunc>(kv.second);
     auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
-    bool is_entry_func = f->HasNonzeroAttr(tir::attr::kIsEntryFunc);
-
-    ICHECK(global_symbol || !is_entry_func) << "The entry func must be exposed 
externally.";
 
     if (global_symbol) {
       function_names_.push_back(global_symbol.value());
-      if (is_entry_func) {
-        entry_func = global_symbol.value();
-      }
     }
   }
   // TODO(@jroesch): follow up on this condition.
@@ -355,11 +341,9 @@ void LLVMModuleNode::Init(const IRModule& mod, const 
Target& target) {
   cg->Init("TVMMod", llvm_target.get(), system_lib_prefix, 
system_lib_prefix.has_value(), false);
   cg->SetFastMathFlags(llvm_target->GetFastMathFlags());
   cg->AddFunctionsOrdered(mod->functions.begin(), mod->functions.end());
-  if (entry_func.length() != 0) {
-    cg->AddMainFunction(entry_func);
-  }
+  q
 
-  module_owning_ptr_ = cg->Finish();
+      module_owning_ptr_ = cg->Finish();
   module_ = module_owning_ptr_.get();
   jit_engine_ = llvm_target->GetJITEngine();
   llvm_target->SetTargetMetadata(module_);
diff --git a/src/target/source/codegen_c_host.cc 
b/src/target/source/codegen_c_host.cc
index 6cd12a9319..1c8a3dd2ea 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -72,20 +72,6 @@ void CodeGenCHost::AddFunction(const GlobalVar& gvar, const 
PrimFunc& func,
 
   emit_fwd_func_decl_ = emit_fwd_func_decl;
   CodeGenC::AddFunction(gvar, func);
-  if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-    ICHECK(global_symbol.has_value())
-        << "CodeGenCHost: The entry func must have the global_symbol 
attribute, "
-        << "but function " << gvar << " only has attributes " << func->attrs;
-
-    function_names_.push_back(runtime::symbol::tvm_module_main);
-    stream << "// CodegenC: NOTE: Auto-generated entry function\n";
-    PrintFuncPrefix(stream);
-    PrintType(func->ret_type, stream);
-    stream << " " << tvm::runtime::symbol::tvm_module_main
-           << "(void* self, void* args,int num_args, void* result) {\n";
-    stream << "  return " << global_symbol.value() << "(self, args, num_args, 
result);\n";
-    stream << "}\n";
-  }
 }
 
 void CodeGenCHost::GenerateForwardFunctionDeclarations(String global_symbol,
diff --git a/src/tir/ir/transform.cc b/src/tir/ir/transform.cc
index aafe6277e2..6ef7ffdca9 100644
--- a/src/tir/ir/transform.cc
+++ b/src/tir/ir/transform.cc
@@ -42,7 +42,6 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_cse_tir", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_debug", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.enable_equiv_terms_in_cse_tir", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
-TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
diff --git a/src/tir/transforms/primfunc_utils.cc 
b/src/tir/transforms/primfunc_utils.cc
index b1f3476eab..99cf901377 100644
--- a/src/tir/transforms/primfunc_utils.cc
+++ b/src/tir/transforms/primfunc_utils.cc
@@ -29,45 +29,6 @@ namespace tvm {
 namespace tir {
 namespace transform {
 
-transform::Pass AnnotateEntryFunc() {
-  auto fpass = [](IRModule mod, transform::PassContext ctx) -> IRModule {
-    // If only a single function exists, that function must be the entry
-    if (mod->functions.size() == 1) {
-      auto [gvar, base_func] = *mod->functions.begin();
-      if (!base_func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
-        if (auto ptr = base_func.as<PrimFuncNode>()) {
-          mod->Update(gvar, WithAttr(GetRef<PrimFunc>(ptr), 
tir::attr::kIsEntryFunc, true));
-        }
-      }
-      return mod;
-    }
-
-    // If the module has multiple functions, but only one is exposed
-    // externally, that function must be the entry.
-    bool has_external_non_primfuncs = false;
-    IRModule with_annotations;
-    for (const auto& [gvar, base_func] : mod->functions) {
-      bool is_external = 
base_func->GetAttr<String>(tvm::attr::kGlobalSymbol).has_value();
-      if (is_external) {
-        if (auto ptr = base_func.as<PrimFuncNode>()) {
-          with_annotations->Add(gvar,
-                                WithAttr(GetRef<PrimFunc>(ptr), 
tir::attr::kIsEntryFunc, true));
-        } else {
-          has_external_non_primfuncs = true;
-        }
-      }
-    }
-    if (with_annotations->functions.size() == 1 && 
!has_external_non_primfuncs) {
-      mod->Update(with_annotations);
-      return mod;
-    }
-
-    // Default fallback, no annotations may be inferred.
-    return mod;
-  };
-  return tvm::transform::CreateModulePass(fpass, 0, "tir.AnnotateEntryFunc", 
{});
-}
-
 transform::Pass Filter(ffi::TypedFunction<bool(PrimFunc)> fcond) {
   auto fpass = [fcond](tir::PrimFunc f, IRModule m, transform::PassContext 
ctx) {
     if (fcond(f)) {
@@ -81,9 +42,7 @@ transform::Pass Filter(ffi::TypedFunction<bool(PrimFunc)> 
fcond) {
 
 TVM_FFI_STATIC_INIT_BLOCK({
   namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef()
-      .def("tir.transform.AnnotateEntryFunc", AnnotateEntryFunc)
-      .def("tir.transform.Filter", Filter);
+  refl::GlobalDef().def("tir.transform.Filter", Filter);
 });
 
 }  // namespace transform
diff --git a/tests/python/codegen/test_target_codegen_device.py 
b/tests/python/codegen/test_target_codegen_device.py
index 4dad03d700..0089e0bea6 100644
--- a/tests/python/codegen/test_target_codegen_device.py
+++ b/tests/python/codegen/test_target_codegen_device.py
@@ -95,7 +95,7 @@ def test_add_pipeline():
         dev = tvm.device(device, 0)
         target = tvm.target.Target(device, host)
         mhost = tvm.tir.build(sch.mod, target=target)
-        f = mhost.entry_func
+        f = mhost["main"]
         # launch the kernel.
         n = 1027
         a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
diff --git a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py 
b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
index fe7a615531..808c016cf8 100644
--- a/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
+++ b/tests/python/contrib/test_hexagon/test_async_dma_pipeline.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 
-""" Test different strategies for loading data into vtcm before running HVX 
workloads. """
+"""Test different strategies for loading data into vtcm before running HVX 
workloads."""
 
 import numpy as np
 import pytest
@@ -289,9 +289,13 @@ def evaluate(
 
     if tvm.testing.utils.IS_IN_CI:
         # Run with reduced number and repeat for CI
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, 
number=1, repeat=1)
+        timer = module.time_evaluator(
+            "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1
+        )
     else:
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, 
number=10, repeat=10)
+        timer = module.time_evaluator(
+            "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10
+        )
 
     time = timer(a_hexagon, b_hexagon, c_hexagon)
     if expected_output is not None:
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx.py 
b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
index 8f77fa1c40..6822352568 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx.py
@@ -160,7 +160,7 @@ def evaluate(hexagon_session, shape_dtypes, 
expected_output_producer, sch):
     repeat = 1
 
     timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
+        "__tvm_ffi_main__", hexagon_session.device, number=number, 
repeat=repeat
     )
     runtime = timer(a_hexagon, b_hexagon, c_hexagon)
     tvm.testing.assert_allclose(c_hexagon.numpy(), 
expected_output_producer(c_shape, a, b))
diff --git a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py 
b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
index a584997dd5..63a65f3716 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_hvx_load_vtcm.py
@@ -330,9 +330,7 @@ def setup_and_run(hexagon_session, sch, a, b, c, 
operations, mem_scope="global")
     number = 1
     repeat = 1
 
-    timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
-    )
+    timer = module.time_evaluator("main", hexagon_session.device, 
number=number, repeat=repeat)
     time = timer(a_hexagon, b_hexagon, c_hexagon)
     gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
     return gops, c_hexagon.numpy()
@@ -364,9 +362,7 @@ def setup_and_run_preallocated(hexagon_session, sch, a, b, 
c, operations):
     number = 1
     repeat = 1
 
-    timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
-    )
+    timer = module.time_evaluator("main", hexagon_session.device, 
number=number, repeat=repeat)
     time = timer(a_hexagon, b_hexagon, c_hexagon, a_vtcm_hexagon, 
b_vtcm_hexagon, c_vtcm_hexagon)
     gops = round(operations * 128 * 3 / time.mean / 1e9, 4)
     return gops, c_hexagon.numpy()
diff --git a/tests/python/contrib/test_hexagon/test_parallel_scalar.py 
b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
index bd9c78d5da..5c8043fdff 100644
--- a/tests/python/contrib/test_hexagon/test_parallel_scalar.py
+++ b/tests/python/contrib/test_hexagon/test_parallel_scalar.py
@@ -104,9 +104,7 @@ def evaluate(hexagon_session, operations, expected, sch):
     number = 1
     repeat = 1
 
-    timer = module.time_evaluator(
-        "__tvm_main__", hexagon_session.device, number=number, repeat=repeat
-    )
+    timer = module.time_evaluator("main", hexagon_session.device, 
number=number, repeat=repeat)
     runtime = timer(a_hexagon, b_hexagon, c_hexagon)
 
     tvm.testing.assert_allclose(c_hexagon.numpy(), expected(a, b))
diff --git a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py 
b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
index 931f99b2ec..265f2bf5fd 100644
--- a/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
+++ b/tests/python/contrib/test_hexagon/test_vtcm_bandwidth.py
@@ -108,9 +108,13 @@ def evaluate(hexagon_session, sch, size):
 
     if tvm.testing.utils.IS_IN_CI:
         # Run with reduced number and repeat for CI
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, 
number=1, repeat=1)
+        timer = module.time_evaluator(
+            "__tvm_ffi_main__", hexagon_session.device, number=1, repeat=1
+        )
     else:
-        timer = module.time_evaluator("__tvm_main__", hexagon_session.device, 
number=10, repeat=10)
+        timer = module.time_evaluator(
+            "__tvm_ffi_main__", hexagon_session.device, number=10, repeat=10
+        )
 
     runtime = timer(a_hexagon, a_vtcm_hexagon)
 
diff --git 
a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py 
b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
index 81acf5ee86..0d2f445cb8 100644
--- a/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
+++ b/tests/python/relax/nvshmem/test_runtime_builtin_kv_cache_transfer.py
@@ -170,7 +170,7 @@ def set_global_func(head_dim, dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
index 1941edeaa7..305fd18f35 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_cpu.py
@@ -140,7 +140,7 @@ def set_global_func(head_dim, dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index ffd3452292..e13ce1ca7b 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -156,7 +156,7 @@ def set_global_func():
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
index 2f726064a7..53044a786c 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_flashinfer.py
@@ -169,7 +169,7 @@ def set_global_func(dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
index b2982abdb0..73a4d89dad 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_mla_tir.py
@@ -134,7 +134,7 @@ def set_global_func(dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 8cd3a73740..44169828e2 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -142,7 +142,7 @@ def set_global_func(head_dim, dtype):
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)
         f = tvm.tir.build(mod["main"], target=target)
-        builts.append(f.entry_func)
+        builts.append(f["main"])
 
     (
         ftranspose_append,
diff --git a/tests/python/relax/test_runtime_builtin_rnn_state.py b/tests/python/relax/test_runtime_builtin_rnn_state.py
index 095aba8b83..fe8c19257d 100644
--- a/tests/python/relax/test_runtime_builtin_rnn_state.py
+++ b/tests/python/relax/test_runtime_builtin_rnn_state.py
@@ -81,7 +81,7 @@ def set_global_func():
         with target:
             mod = dl.ApplyDefaultSchedule(dl.gpu.Fallback())(mod)  # pylint: disable=not-callable
         f = tvm.tir.build(mod["main"], target=target)
-        return f.entry_func
+        return f["main"]
 
     _f_tir_gets, _f_tir_sets = [], []
     for state in states:
diff --git a/tests/python/tir-transform/test_tir_transform_helpers.py b/tests/python/tir-transform/test_tir_transform_helpers.py
index 0bbd0e7160..a7ac11d7a9 100644
--- a/tests/python/tir-transform/test_tir_transform_helpers.py
+++ b/tests/python/tir-transform/test_tir_transform_helpers.py
@@ -21,27 +21,6 @@ from tvm.script import tir as T, ir as I
 import tvm.testing
 
 
-def test_annotate_entry_func_single_primfunc():
-    @tvm.script.ir_module
-    class MockModule:
-        @T.prim_func(private=True)
-        def func1(A: T.Buffer((16,), "float32")):
-            for i in T.serial(16):
-                if i == 5:
-                    if i == 5:
-                        A[i] = 0.0
-
-    mod = MockModule
-    assert mod
-    assert not mod["func1"].attrs
-    after = tvm.tir.transform.AnnotateEntryFunc()(mod)
-    assert (
-        after["func1"].attrs
-        and "tir.is_entry_func" in after["func1"].attrs
-        and after["func1"].attrs["tir.is_entry_func"]
-    )
-
-
 # Test module
 @tvm.script.ir_module
 class MockModule:
@@ -60,16 +39,6 @@ class MockModule:
                     A[i] = 0.0
 
 
-@pytest.mark.xfail
-def test_annotate_entry_func_multiple_primfunc():
-    mod = MockModule
-    assert mod
-    assert not mod["func1"].attrs
-    assert not mod["func2"].attrs
-    # This should fail
-    after = tvm.tir.transform.AnnotateEntryFunc()(mod)
-
-
 def test_bind_target():
     mod = MockModule
     assert mod
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 0e1b328844..73ca1dad3b 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -196,7 +196,6 @@ def opt_gemm_mod_host():
             T.func_attr(
                 {
                     "tir.noalias": True,
-                    "tir.is_entry_func": True,
                     "calling_conv": 1,
                 }
             )
@@ -2242,7 +2241,6 @@ def opt_conv_tensorcore_mod_host():
             {
                 "tir.noalias": True,
                 "global_symbol": "default_function",
-                "tir.is_entry_func": True,
                 "calling_conv": 1,
             }
         )


Reply via email to