This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3b26ce21c0 [TIR] Avoid duplicate GlobalVar names in SplitHostDevice (#15119)
3b26ce21c0 is described below

commit 3b26ce21c0af9b74d478a17ea9707e9f0bb6b148
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Jun 21 13:46:57 2023 -0500

    [TIR] Avoid duplicate GlobalVar names in SplitHostDevice (#15119)
    
    * [TIR] Avoid duplicate GlobalVar names in SplitHostDevice
    
    Previously, the names were de-duplicated only within the device module.
    With a single-module lowering flow, the names must be de-duplicated
    across both the host and device modules.
    
    * Updated expected TIR
---
 src/tir/transforms/split_host_device.cc            | 36 +++++++--------
 .../test_tir_transform_split_host_device.py        | 51 ++++++++++++++++++++++
 2 files changed, 70 insertions(+), 17 deletions(-)

diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 29ecaa4e8e..ac5dc7131d 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -43,8 +43,8 @@ namespace tir {
 
 class HostDeviceSplitter : public StmtMutator {
  public:
-  explicit HostDeviceSplitter(IRModule* device_mod, std::string name_prefix)
-      : device_mod_(device_mod), name_prefix_(name_prefix) {}
+  explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()> 
var_supply)
+      : device_mod_(device_mod), var_supply_(var_supply) {}
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == tvm::attr::kTarget) {
@@ -74,13 +74,7 @@ class HostDeviceSplitter : public StmtMutator {
       return params;
     }();
 
-    GlobalVar kernel_symbol_global = [&]() {
-      std::stringstream name;
-      name << name_prefix_ << "_kernel";
-      GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
-      return global_var_supply->FreshGlobal(name.str(), false);
-    }();
-
+    GlobalVar kernel_symbol_global = var_supply_();
     PrimFunc device_func(params, body);
     device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget, 
device_target},
                                                      {tir::attr::kNoAlias, 
Bool(true)},
@@ -94,15 +88,13 @@ class HostDeviceSplitter : public StmtMutator {
 
   // target ir module
   IRModule* device_mod_;
-  // function name hint
-  std::string name_prefix_;
+  // Generate new GlobalVar for the kernel
+  std::function<GlobalVar()> var_supply_;
 };
 
-PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& 
gvar) {
-  auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
-  auto name_prefix = global_symbol.value_or(gvar->name_hint);
-
-  HostDeviceSplitter splitter(device_mod, name_prefix);
+PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod,
+                         std::function<GlobalVar()> var_supply) {
+  HostDeviceSplitter splitter(device_mod, var_supply);
 
   if (auto body = splitter(func->body); !body.same_as(func->body)) {
     func.CopyOnWrite()->body = body;
@@ -115,13 +107,23 @@ namespace transform {
 
 Pass SplitHostDevice() {
   auto pass_func = [](IRModule mod, PassContext ctx) {
+    GlobalVarSupply global_var_supply(mod);
+
     IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
     IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));
 
     for (const auto& [gvar, base_func] : mod->functions) {
       if (auto opt = base_func.as<PrimFunc>()) {
         PrimFunc func = opt.value();
-        func = SplitHostDevice(std::move(func), &device_mod, gvar);
+
+        auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+        auto name_prefix = global_symbol.value_or(gvar->name_hint);
+        auto kernel_name = name_prefix + "_kernel";
+        auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar {
+          return global_var_supply->FreshGlobal(kernel_name, false);
+        };
+
+        func = SplitHostDevice(std::move(func), &device_mod, var_supply);
         if (!func.same_as(base_func)) {
           updates->Add(gvar, func);
         }
diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py
index 60bfb8a718..ca16fe908f 100644
--- a/tests/python/unittest/test_tir_transform_split_host_device.py
+++ b/tests/python/unittest/test_tir_transform_split_host_device.py
@@ -184,5 +184,56 @@ class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare):
     expected = before
 
 
+class TestSplitHostDeviceNameCollision(BaseCompare):
+    """Like TestSplitHostDevice, but with the default kernel name already taken.
+
+    The default name is generated as `func.name + "_kernel"`.  If this
+    name is already used by any function in the IRModule (host or device),
+    SplitHostDevice must pick a fresh name instead, here `main_kernel_1`.
+    """
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(n: T.int32):
+                T.func_attr({"target": T.target("cuda", host="llvm 
-opt-level=0")})
+                T.attr(T.target("cuda"), "target", 0)
+                T.evaluate(n)
+
+            @T.prim_func
+            def main_kernel():
+                T.func_attr({"target": T.target("llvm")})
+                T.evaluate(0)
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(n: T.int32):
+                T.func_attr({"target": T.target("cuda", host="llvm 
-opt-level=0")})
+                mod.main_kernel_1(n)
+
+            @T.prim_func
+            def main_kernel_1(n: T.int32):
+                T.func_attr(
+                    {
+                        "target": T.target("cuda"),
+                        "tir.noalias": T.bool(True),
+                        "tir.is_global_func": True,
+                    }
+                )
+                T.evaluate(n)
+
+            @T.prim_func
+            def main_kernel():
+                T.func_attr({"target": T.target("llvm")})
+                T.evaluate(0)
+
+        return mod
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to