This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3b26ce21c0 [TIR] Avoid duplicate GlobalVar names in SplitHostDevice (#15119)
3b26ce21c0 is described below
commit 3b26ce21c0af9b74d478a17ea9707e9f0bb6b148
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Jun 21 13:46:57 2023 -0500
[TIR] Avoid duplicate GlobalVar names in SplitHostDevice (#15119)
* [TIR] Avoid duplicate GlobalVar names in SplitHostDevice
Previously, the generated kernel names were only de-duplicated within the device module.
With a single-module lowering flow, the names must be de-duplicated
across both the host and device modules.
* Updated expected TIR
---
src/tir/transforms/split_host_device.cc | 36 +++++++--------
.../test_tir_transform_split_host_device.py | 51 ++++++++++++++++++++++
2 files changed, 70 insertions(+), 17 deletions(-)
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 29ecaa4e8e..ac5dc7131d 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -43,8 +43,8 @@ namespace tir {
class HostDeviceSplitter : public StmtMutator {
public:
- explicit HostDeviceSplitter(IRModule* device_mod, std::string name_prefix)
- : device_mod_(device_mod), name_prefix_(name_prefix) {}
+ explicit HostDeviceSplitter(IRModule* device_mod, std::function<GlobalVar()>
var_supply)
+ : device_mod_(device_mod), var_supply_(var_supply) {}
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == tvm::attr::kTarget) {
@@ -74,13 +74,7 @@ class HostDeviceSplitter : public StmtMutator {
return params;
}();
- GlobalVar kernel_symbol_global = [&]() {
- std::stringstream name;
- name << name_prefix_ << "_kernel";
- GlobalVarSupply global_var_supply = GlobalVarSupply(*device_mod_);
- return global_var_supply->FreshGlobal(name.str(), false);
- }();
-
+ GlobalVar kernel_symbol_global = var_supply_();
PrimFunc device_func(params, body);
device_func = WithAttrs(std::move(device_func), {{tvm::attr::kTarget,
device_target},
{tir::attr::kNoAlias,
Bool(true)},
@@ -94,15 +88,13 @@ class HostDeviceSplitter : public StmtMutator {
// target ir module
IRModule* device_mod_;
- // function name hint
- std::string name_prefix_;
+ // Generate new GlobalVar for the kernel
+ std::function<GlobalVar()> var_supply_;
};
-PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar&
gvar) {
- auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
- auto name_prefix = global_symbol.value_or(gvar->name_hint);
-
- HostDeviceSplitter splitter(device_mod, name_prefix);
+PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod,
+ std::function<GlobalVar()> var_supply) {
+ HostDeviceSplitter splitter(device_mod, var_supply);
if (auto body = splitter(func->body); !body.same_as(func->body)) {
func.CopyOnWrite()->body = body;
@@ -115,13 +107,23 @@ namespace transform {
Pass SplitHostDevice() {
auto pass_func = [](IRModule mod, PassContext ctx) {
+ GlobalVarSupply global_var_supply(mod);
+
IRModule device_mod = IRModule(Map<GlobalVar, BaseFunc>({}));
IRModule updates = IRModule(Map<GlobalVar, BaseFunc>({}));
for (const auto& [gvar, base_func] : mod->functions) {
if (auto opt = base_func.as<PrimFunc>()) {
PrimFunc func = opt.value();
- func = SplitHostDevice(std::move(func), &device_mod, gvar);
+
+ auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ auto name_prefix = global_symbol.value_or(gvar->name_hint);
+ auto kernel_name = name_prefix + "_kernel";
+ auto var_supply = [&global_var_supply, &kernel_name]() -> GlobalVar {
+ return global_var_supply->FreshGlobal(kernel_name, false);
+ };
+
+ func = SplitHostDevice(std::move(func), &device_mod, var_supply);
if (!func.same_as(base_func)) {
updates->Add(gvar, func);
}
diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py
index 60bfb8a718..ca16fe908f 100644
--- a/tests/python/unittest/test_tir_transform_split_host_device.py
+++ b/tests/python/unittest/test_tir_transform_split_host_device.py
@@ -184,5 +184,56 @@ class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare):
expected = before
+class TestSplitHostDeviceNameCollision(BaseCompare):
+ """Like TestSplitHostDevice, but with the default name already taken
+
+ The default name is generated as `func.name + "_kernel"`. If this
+ name is already taken by another function in the IRModule, then
+ SplitHostDevice should select a different name.
+ """
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(n: T.int32):
+ T.func_attr({"target": T.target("cuda", host="llvm
-opt-level=0")})
+ T.attr(T.target("cuda"), "target", 0)
+ T.evaluate(n)
+
+ @T.prim_func
+ def main_kernel():
+ T.func_attr({"target": T.target("llvm")})
+ T.evaluate(0)
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(n: T.int32):
+ T.func_attr({"target": T.target("cuda", host="llvm
-opt-level=0")})
+ mod.main_kernel_1(n)
+
+ @T.prim_func
+ def main_kernel_1(n: T.int32):
+ T.func_attr(
+ {
+ "target": T.target("cuda"),
+ "tir.noalias": T.bool(True),
+ "tir.is_global_func": True,
+ }
+ )
+ T.evaluate(n)
+
+ @T.prim_func
+ def main_kernel():
+ T.func_attr({"target": T.target("llvm")})
+ T.evaluate(0)
+
+ return mod
+
+
if __name__ == "__main__":
tvm.testing.main()