This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 56255f8dc5 [TIR] Move SplitHostDevice to before MakePackedAPI (#14986)
56255f8dc5 is described below
commit 56255f8dc52144734c52538403188d0f834c4a86
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jun 8 02:08:56 2023 -0400
[TIR] Move SplitHostDevice to before MakePackedAPI (#14986)
* [TIR] Move SplitHostDevice to before MakePackedAPI
This simplifies the logic used in MakePackedAPI, which is now the last user
of the host parameter in a function's target. After MakePackedAPI,
every PrimFunc has a "target" attribute without a "host".
* Roofline plots, update location for SaveLoweredTIR
* Update ethos-u tests to include host prior to MakeUnpackedAPI
---
python/tvm/utils/roofline/__init__.py | 12 +++-
src/driver/driver_api.cc | 5 +-
src/tir/transforms/make_packed_api.cc | 11 +++-
src/tir/transforms/make_unpacked_api.cc | 10 +++-
src/tir/transforms/split_host_device.cc | 8 ---
.../contrib/test_ethosu/test_encode_constants.py | 4 +-
.../test_ethosu/test_tir_to_cs_translator.py | 8 ++-
.../unittest/test_tir_transform_make_packed_api.py | 40 ++++++++++---
.../test_tir_transform_make_unpacked_api.py | 65 +++++++++++++++++++---
.../test_tir_transform_split_host_device.py | 14 ++---
10 files changed, 136 insertions(+), 41 deletions(-)
diff --git a/python/tvm/utils/roofline/__init__.py
b/python/tvm/utils/roofline/__init__.py
index 67d80eb052..45cc880c5b 100644
--- a/python/tvm/utils/roofline/__init__.py
+++ b/python/tvm/utils/roofline/__init__.py
@@ -51,10 +51,16 @@ def _create_args(mod: IRModule, dev: Device, func_name: str
= "main", remote=Non
@pass_instrument
class SaveLoweredTIR:
- """Save TIR functions from right before final lowering. Right now this
- means right before tir.MakePackedAPI."""
+ """Save TIR functions for analysis.
- def __init__(self, before_pass: str = "tir.MakePackedAPI"):
+ We need the TIR function in a form that can be handled by
+ `auto_scheduler.feature.named_features_from_primfunc`, but which
+ is as close to the final lowered form as possible. Right now this
+ means right before tir.SplitHostDevice.
+
+ """
+
+ def __init__(self, before_pass: str = "tir.SplitHostDevice"):
"""
Parameters
----------
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index cfc7fa80c7..b75f173e0d 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -578,6 +578,9 @@ transform::Sequential MixedModulePassManager(IRModule
mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
}
+ mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
+ mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+
bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
.value_or(relay::Executor::Create("graph", {}))
->GetAttr<Bool>("unpacked-api")
@@ -590,8 +593,6 @@ transform::Sequential MixedModulePassManager(IRModule
mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
- mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
- mixed_pass_list.push_back(tir::transform::SplitHostDevice());
mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
return transform::Sequential(mixed_pass_list);
diff --git a/src/tir/transforms/make_packed_api.cc
b/src/tir/transforms/make_packed_api.cc
index 062f1c0509..a6673a19ad 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -223,6 +223,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();
+ // A function without a host target has already been lowered.
+ Target target_host;
+ if (auto opt = target->GetHost()) {
+ target_host = opt.value();
+ } else {
+ return func;
+ }
+
auto* func_ptr = func.CopyOnWrite();
const Stmt nop = Evaluate(0);
int num_args = static_cast<int>(func_ptr->params.size());
@@ -325,7 +333,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
name_hint + "." + kv.first->name_hint);
}
- func = WithAttr(std::move(func), tvm::attr::kCallingConv,
Integer(CallingConv::kCPackedFunc));
+ func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv,
Integer(CallingConv::kCPackedFunc)},
+ {tvm::attr::kTarget, target_host}});
Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
diff --git a/src/tir/transforms/make_unpacked_api.cc
b/src/tir/transforms/make_unpacked_api.cc
index bdb3a953e9..4b1b3bf517 100644
--- a/src/tir/transforms/make_unpacked_api.cc
+++ b/src/tir/transforms/make_unpacked_api.cc
@@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
}();
int target_device_type = target->GetTargetDeviceType();
+ // A function without a host target has already been lowered.
+ Target target_host;
+ if (auto opt = target->GetHost()) {
+ target_host = opt.value();
+ } else {
+ return func;
+ }
+
auto* func_ptr = func.CopyOnWrite();
// Setup device context
@@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
func_ptr->buffer_map = Map<Var, Buffer>();
// return the function.
- return func;
+ return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}});
}
namespace transform {
diff --git a/src/tir/transforms/split_host_device.cc
b/src/tir/transforms/split_host_device.cc
index 2de831e8ad..29ecaa4e8e 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator {
};
PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar&
gvar) {
- auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
- ICHECK(opt_target) << "SplitHostDevice: Require the target attribute";
- Target target = opt_target.value();
-
auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
auto name_prefix = global_symbol.value_or(gvar->name_hint);
@@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule*
device_mod, const GlobalVar& g
func.CopyOnWrite()->body = body;
}
- if (auto target_host = target->GetHost()) {
- func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
- }
-
return func;
}
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py
b/tests/python/contrib/test_ethosu/test_encode_constants.py
index 0309768452..6a8ff28e44 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -525,7 +525,9 @@ def test_constant_as_input():
# nothing else was overrwritten.
# With Target Hooks the TIR module needs a target attached
# and lowered via make unpacked API.
- tir_mod["main"] = tir_mod["main"].with_attr("target",
tvm.target.Target("ethos-u"))
+ tir_mod["main"] = tir_mod["main"].with_attr(
+ "target", tvm.target.Target("ethos-u", host="ethos-u")
+ )
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
tir_to_cs_translator.translate(tir_mod, params)
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index 22f886a591..a293e26919 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -255,7 +255,9 @@ def test_buffer_info_extraction():
# With Target Hooks the TIR module needs a target attached
# and lowered via make unpacked API.
tir_mod = test_case["tir_module"]
- tir_mod["main"] = tir_mod["main"].with_attr("target",
tvm.target.Target("ethos-u"))
+ tir_mod["main"] = tir_mod["main"].with_attr(
+ "target", tvm.target.Target("ethos-u", host="ethos-u")
+ )
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod,
test_case["param_dict"])
for buffer_var, info in buffer_info.items():
@@ -959,7 +961,9 @@ def test_assign_addresses():
for test_case in test_cases:
tir_mod = test_case["tir_module"]
- tir_mod["main"] = tir_mod["main"].with_attr("target",
tvm.target.Target("ethos-u"))
+ tir_mod["main"] = tir_mod["main"].with_attr(
+ "target", tvm.target.Target("ethos-u", host="ethos-u")
+ )
tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
candidate_regions_for_scratch = [5, 2, 1]
(
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py
b/tests/python/unittest/test_tir_transform_make_packed_api.py
index 8af7efb596..34adcbb9ae 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -37,7 +37,7 @@ def test_makeapi():
mod = tvm.tir.transform.Apply(
lambda f: f.with_attr(
{
- "target": tvm.target.Target("llvm"),
+ "target": tvm.target.Target("llvm", host="llvm"),
"global_symbol": "main",
}
)
@@ -90,7 +90,9 @@ def test_variable_passed_from_args():
stmt = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer,
not_device_context], stmt))
- mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
tvm.target.Target("llvm")))(mod)
+ mod = tvm.tir.transform.Apply(
+ lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
+ )(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol",
"main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]
@@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle():
stmt = ib.get()
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer,
device_context], stmt))
- mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
tvm.target.Target("llvm")))(mod)
+ mod = tvm.tir.transform.Apply(
+ lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
+ )(mod)
mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol",
"main"))(mod)
func = tvm.tir.transform.MakePackedAPI()(mod)["main"]
@@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle():
@pytest.mark.parametrize("use_global_symbol", [True, False])
def test_no_op_when_global_symbol_is_absent(use_global_symbol):
- func_attr = {"target": tvm.target.Target("llvm")}
+ func_attr = {"target": tvm.target.Target("llvm", host="llvm")}
if use_global_symbol:
func_attr["global_symbol"] = "main"
@@ -177,6 +181,28 @@ def
test_no_op_when_global_symbol_is_absent(use_global_symbol):
tvm.ir.assert_structural_equal(before, after)
+def test_target_host_removed():
+ """After MakePackedAPI, host-side target should be the host
+
+ MakePackedAPI is the last transform that requires both the device
+ and the host. After MakePackedAPI, the target attribute should
+ only contain the host-side target.
+ """
+
+ host = tvm.target.Target("llvm")
+
+ @I.ir_module
+ class before:
+ @T.prim_func
+ def main(A: T.Buffer(1, "float32")):
+ T.func_attr({"global_symbol": "main", "target": T.target("cuda",
host=host)})
+ T.evaluate(0)
+
+ after = tvm.tir.transform.MakePackedAPI()(before)
+ target_attr = after["main"].attrs["target"]
+ assert str(host) == str(target_attr)
+
+
def test_internal_subroutine_call():
"""Internal subroutines should not use the PackedFunc API
@@ -190,7 +216,7 @@ def test_internal_subroutine_call():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
- T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
+ T.func_attr({"global_symbol": "main", "target": T.target("llvm",
host="llvm")})
before.subroutine(A.data)
@T.prim_func
@@ -222,12 +248,12 @@ def
test_subroutine_call_to_externally_visible_subroutine():
class before:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
- T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
+ T.func_attr({"global_symbol": "main", "target": T.target("llvm",
host="llvm")})
before.subroutine(A.data)
@T.prim_func
def subroutine(A_data: T.handle("float32")):
- T.func_attr({"global_symbol": "subroutine", "target":
T.target("llvm")})
+ T.func_attr({"global_symbol": "subroutine", "target":
T.target("llvm", host="llvm")})
T.evaluate(A_data)
after = tvm.tir.transform.MakePackedAPI()(before)
diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py
b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
index bb9fe8ab82..1931f7aef3 100644
--- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py
+++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
@@ -41,9 +41,8 @@ def mod(mod_without_attrs):
def test_noop_if_not_global_symbol(mod_without_attrs):
- before = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
tvm.target.Target("llvm")))(
- mod_without_attrs
- )
+ target = tvm.target.Target("llvm", host="llvm")
+ before = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
target))(mod_without_attrs)
after = tvm.tir.transform.MakeUnpackedAPI()(before)
tvm.ir.assert_structural_equal(before, after)
@@ -59,7 +58,8 @@ def test_fails_if_no_target(mod_without_attrs):
@tvm.testing.parametrize_targets("c", "llvm", "cuda")
def test_device_setup(mod, target, dev):
- mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target",
tvm.target.Target(target)))(mod)
+ target = tvm.target.Target(target, host="llvm")
+ mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
assert len(f.params) == 1
assert f.params[0].name == "A"
@@ -138,6 +138,49 @@ def test_body():
assert f.params[2].name == "A"
+class TestTargetHostRemoved(tvm.testing.CompareBeforeAfter):
+ """After MakeUnpackedAPI, host-side target should be the host
+
+ MakeUnpackedAPI is the last transform that requires both the device
+ and the host. After MakeUnpackedAPI, the target attribute should
+ only contain the host-side target.
+ """
+
+ transform = tvm.tir.transform.MakeUnpackedAPI()
+
+ def before(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A: T.Buffer(1, "float32")):
+ T.func_attr({"global_symbol": "main", "target":
T.target("cuda", host="llvm")})
+ mod.subroutine(A.data)
+
+ @T.prim_func
+ def subroutine(A_data: T.handle("float32")):
+ T.func_attr({"target": T.target("cuda")})
+ T.evaluate(A_data)
+
+ return mod
+
+ def expected(self):
+ @I.ir_module
+ class mod:
+ @T.prim_func
+ def main(A_data: T.handle("float32")) -> T.int32:
+ T.func_attr({"global_symbol": "main", "target":
T.target("llvm")})
+ T.attr("default", "device_id", 0)
+ T.attr("default", "device_type", 2)
+ mod.subroutine(A_data)
+
+ @T.prim_func
+ def subroutine(A_data: T.handle("float32")):
+ T.func_attr({"target": T.target("cuda")})
+ T.evaluate(A_data)
+
+ return mod
+
+
class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
"""Internal subroutines do not require modification
@@ -153,7 +196,7 @@ class
TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
- T.func_attr({"global_symbol": "main", "target":
T.target("llvm")})
+ T.func_attr({"global_symbol": "main", "target":
T.target("llvm", host="llvm")})
mod.subroutine(A.data)
@T.prim_func
@@ -195,12 +238,14 @@ class
TestSubroutineCallToExternallyVisibleSubroutine(tvm.testing.CompareBeforeA
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
- T.func_attr({"global_symbol": "main", "target":
T.target("llvm")})
+ T.func_attr({"global_symbol": "main", "target":
T.target("llvm", host="llvm")})
mod.subroutine(A.data)
@T.prim_func
def subroutine(A_data: T.handle("float32")):
- T.func_attr({"global_symbol": "subroutine", "target":
T.target("llvm")})
+ T.func_attr(
+ {"global_symbol": "subroutine", "target": T.target("llvm",
host="llvm")}
+ )
T.evaluate(A_data)
return mod
@@ -240,7 +285,7 @@ class
TestCallExternallyVisibleSubroutineWithDLTensor(tvm.testing.CompareBeforeA
class mod:
@T.prim_func
def main(A: T.Buffer(1, "float32")):
- T.func_attr({"global_symbol": "main", "target":
T.target("llvm")})
+ T.func_attr({"global_symbol": "main", "target":
T.target("llvm", host="llvm")})
mod.subroutine(
T.tvm_stack_make_array(
A.data,
@@ -255,7 +300,9 @@ class
TestCallExternallyVisibleSubroutineWithDLTensor(tvm.testing.CompareBeforeA
@T.prim_func
def subroutine(A: T.Buffer(1, "float32")):
- T.func_attr({"global_symbol": "subroutine", "target":
T.target("llvm")})
+ T.func_attr(
+ {"global_symbol": "subroutine", "target": T.target("llvm",
host="llvm")}
+ )
T.evaluate(A.data)
return mod
diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py
b/tests/python/unittest/test_tir_transform_split_host_device.py
index 1599b9a031..60bfb8a718 100644
--- a/tests/python/unittest/test_tir_transform_split_host_device.py
+++ b/tests/python/unittest/test_tir_transform_split_host_device.py
@@ -46,6 +46,7 @@ def test_split_host_device_func_attr():
[
tvm.tir.transform.AnnotateDeviceRegions(),
tvm.tir.transform.SplitHostDevice(),
+ tvm.tir.transform.MakePackedAPI(),
tvm.tir.transform.LowerDeviceKernelLaunch(),
]
)(mod)
@@ -111,7 +112,7 @@ class TestSplitHostDevice(BaseCompare):
class mod:
@T.prim_func
def main(n: T.int32):
- T.func_attr({"target": T.target("llvm -opt-level=0")})
+ T.func_attr({"target": T.target("cuda", host="llvm
-opt-level=0")})
mod.main_kernel(n)
@T.prim_func
@@ -168,20 +169,19 @@ class
TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare):
return mod
-class TestSplitHostDevice(BaseCompare):
+class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare):
"""Like TestSplitHostDevice, but no device regions to extract
- Even if there are no device regions, the host-side function should
- still have its "target" attribute updated.
+ Because MakePackedAPI/MakeUnpackedAPI still require both the
+ device and host, SplitHostDevice does not modify the "target"
+ attribute.
"""
def before():
T.func_attr({"target": T.target("ext_dev", host="llvm")})
T.evaluate(0)
- def expected():
- T.func_attr({"target": T.target("llvm")})
- T.evaluate(0)
+ expected = before
if __name__ == "__main__":