This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 56255f8dc5 [TIR] Move SplitHostDevice to before MakePackedAPI (#14986)
56255f8dc5 is described below

commit 56255f8dc52144734c52538403188d0f834c4a86
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jun 8 02:08:56 2023 -0400

    [TIR] Move SplitHostDevice to before MakePackedAPI (#14986)
    
    * [TIR] Move SplitHostDevice to before MakePackedAPI
    
    This simplifies the logic used in MakePackedAPI, as it is now the last
    user of the host parameter in a function's target (MakeUnpackedAPI
    plays the same role when the unpacked API is used).  After
    MakePackedAPI, every PrimFunc has a "target" attribute without a "host".
    
    * Roofline plots: move the SaveLoweredTIR instrument point to before tir.SplitHostDevice
    
    * Update ethos-u tests to include a host target before calling MakeUnpackedAPI
---
 python/tvm/utils/roofline/__init__.py              | 12 +++-
 src/driver/driver_api.cc                           |  5 +-
 src/tir/transforms/make_packed_api.cc              | 11 +++-
 src/tir/transforms/make_unpacked_api.cc            | 10 +++-
 src/tir/transforms/split_host_device.cc            |  8 ---
 .../contrib/test_ethosu/test_encode_constants.py   |  4 +-
 .../test_ethosu/test_tir_to_cs_translator.py       |  8 ++-
 .../unittest/test_tir_transform_make_packed_api.py | 40 ++++++++++---
 .../test_tir_transform_make_unpacked_api.py        | 65 +++++++++++++++++++---
 .../test_tir_transform_split_host_device.py        | 14 ++---
 10 files changed, 136 insertions(+), 41 deletions(-)

diff --git a/python/tvm/utils/roofline/__init__.py b/python/tvm/utils/roofline/__init__.py
index 67d80eb052..45cc880c5b 100644
--- a/python/tvm/utils/roofline/__init__.py
+++ b/python/tvm/utils/roofline/__init__.py
@@ -51,10 +51,16 @@ def _create_args(mod: IRModule, dev: Device, func_name: str = "main", remote=Non
 
 @pass_instrument
 class SaveLoweredTIR:
-    """Save TIR functions from right before final lowering. Right now this
-    means right before tir.MakePackedAPI."""
+    """Save TIR functions for analysis.
 
-    def __init__(self, before_pass: str = "tir.MakePackedAPI"):
+    We need the TIR function in a form that can be handled by
+    `auto_scheduler.feature.named_features_from_primfunc`, but which
+    is the closest to the final lowered form as possible.  Right now this
+    means right before tir.SplitHostDevice.
+
+    """
+
+    def __init__(self, before_pass: str = "tir.SplitHostDevice"):
         """
         Parameters
         ----------
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index cfc7fa80c7..b75f173e0d 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -578,6 +578,9 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
     mixed_pass_list.push_back(tir::transform::InjectPTXLDG32());
   }
 
+  mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
+  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
+
   bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
                           .value_or(relay::Executor::Create("graph", {}))
                           ->GetAttr<Bool>("unpacked-api")
@@ -590,8 +593,6 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
   mixed_pass_list.push_back(tir::transform::FP8StorageLegalize());
   mixed_pass_list.push_back(tir::transform::BF16StorageLegalize());
 
-  mixed_pass_list.push_back(tir::transform::AnnotateDeviceRegions());
-  mixed_pass_list.push_back(tir::transform::SplitHostDevice());
   mixed_pass_list.push_back(tir::transform::LowerDeviceKernelLaunch());
 
   return transform::Sequential(mixed_pass_list);
diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc
index 062f1c0509..a6673a19ad 100644
--- a/src/tir/transforms/make_packed_api.cc
+++ b/src/tir/transforms/make_packed_api.cc
@@ -223,6 +223,14 @@ PrimFunc MakePackedAPI(PrimFunc func) {
   }();
   int target_device_type = target->GetTargetDeviceType();
 
+  // A function without a host target has already been lowered.
+  Target target_host;
+  if (auto opt = target->GetHost()) {
+    target_host = opt.value();
+  } else {
+    return func;
+  }
+
   auto* func_ptr = func.CopyOnWrite();
   const Stmt nop = Evaluate(0);
   int num_args = static_cast<int>(func_ptr->params.size());
@@ -325,7 +333,8 @@ PrimFunc MakePackedAPI(PrimFunc func) {
                         name_hint + "." + kv.first->name_hint);
   }
 
-  func = WithAttr(std::move(func), tvm::attr::kCallingConv, 
Integer(CallingConv::kCPackedFunc));
+  func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv, 
Integer(CallingConv::kCPackedFunc)},
+                                     {tvm::attr::kTarget, target_host}});
 
   Stmt body = RewriteReturn(func_ptr->body, v_out_ret_value, v_out_ret_tcode);
   body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
diff --git a/src/tir/transforms/make_unpacked_api.cc b/src/tir/transforms/make_unpacked_api.cc
index bdb3a953e9..4b1b3bf517 100644
--- a/src/tir/transforms/make_unpacked_api.cc
+++ b/src/tir/transforms/make_unpacked_api.cc
@@ -111,6 +111,14 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
   }();
   int target_device_type = target->GetTargetDeviceType();
 
+  // A function without a host target has already been lowered.
+  Target target_host;
+  if (auto opt = target->GetHost()) {
+    target_host = opt.value();
+  } else {
+    return func;
+  }
+
   auto* func_ptr = func.CopyOnWrite();
 
   // Setup device context
@@ -145,7 +153,7 @@ PrimFunc MakeUnpackedAPI(PrimFunc func) {
   func_ptr->buffer_map = Map<Var, Buffer>();
 
   // return the function.
-  return func;
+  return WithAttrs(std::move(func), {{tvm::attr::kTarget, target_host}});
 }
 
 namespace transform {
diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc
index 2de831e8ad..29ecaa4e8e 100644
--- a/src/tir/transforms/split_host_device.cc
+++ b/src/tir/transforms/split_host_device.cc
@@ -99,10 +99,6 @@ class HostDeviceSplitter : public StmtMutator {
 };
 
 PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& 
gvar) {
-  auto opt_target = func->GetAttr<Target>(tvm::attr::kTarget);
-  ICHECK(opt_target) << "SplitHostDevice: Require the target attribute";
-  Target target = opt_target.value();
-
   auto global_symbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
   auto name_prefix = global_symbol.value_or(gvar->name_hint);
 
@@ -112,10 +108,6 @@ PrimFunc SplitHostDevice(PrimFunc func, IRModule* device_mod, const GlobalVar& g
     func.CopyOnWrite()->body = body;
   }
 
-  if (auto target_host = target->GetHost()) {
-    func = WithAttr(std::move(func), tvm::attr::kTarget, target_host.value());
-  }
-
   return func;
 }
 
diff --git a/tests/python/contrib/test_ethosu/test_encode_constants.py b/tests/python/contrib/test_ethosu/test_encode_constants.py
index 0309768452..6a8ff28e44 100644
--- a/tests/python/contrib/test_ethosu/test_encode_constants.py
+++ b/tests/python/contrib/test_ethosu/test_encode_constants.py
@@ -525,7 +525,9 @@ def test_constant_as_input():
     # nothing else was overrwritten.
     # With Target Hooks the TIR module needs a target attached
     # and lowered via make unpacked API.
-    tir_mod["main"] = tir_mod["main"].with_attr("target", 
tvm.target.Target("ethos-u"))
+    tir_mod["main"] = tir_mod["main"].with_attr(
+        "target", tvm.target.Target("ethos-u", host="ethos-u")
+    )
     tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
     tir_to_cs_translator.translate(tir_mod, params)
 
diff --git a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
index 22f886a591..a293e26919 100644
--- a/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
+++ b/tests/python/contrib/test_ethosu/test_tir_to_cs_translator.py
@@ -255,7 +255,9 @@ def test_buffer_info_extraction():
         # With Target Hooks the TIR module needs a target attached
         # and lowered via make unpacked API.
         tir_mod = test_case["tir_module"]
-        tir_mod["main"] = tir_mod["main"].with_attr("target", 
tvm.target.Target("ethos-u"))
+        tir_mod["main"] = tir_mod["main"].with_attr(
+            "target", tvm.target.Target("ethos-u", host="ethos-u")
+        )
         tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
         buffer_info = tir_to_cs_translator.extract_buffer_info(tir_mod, 
test_case["param_dict"])
         for buffer_var, info in buffer_info.items():
@@ -959,7 +961,9 @@ def test_assign_addresses():
 
     for test_case in test_cases:
         tir_mod = test_case["tir_module"]
-        tir_mod["main"] = tir_mod["main"].with_attr("target", 
tvm.target.Target("ethos-u"))
+        tir_mod["main"] = tir_mod["main"].with_attr(
+            "target", tvm.target.Target("ethos-u", host="ethos-u")
+        )
         tir_mod = tvm.tir.transform.MakeUnpackedAPI()(tir_mod)
         candidate_regions_for_scratch = [5, 2, 1]
         (
diff --git a/tests/python/unittest/test_tir_transform_make_packed_api.py b/tests/python/unittest/test_tir_transform_make_packed_api.py
index 8af7efb596..34adcbb9ae 100644
--- a/tests/python/unittest/test_tir_transform_make_packed_api.py
+++ b/tests/python/unittest/test_tir_transform_make_packed_api.py
@@ -37,7 +37,7 @@ def test_makeapi():
     mod = tvm.tir.transform.Apply(
         lambda f: f.with_attr(
             {
-                "target": tvm.target.Target("llvm"),
+                "target": tvm.target.Target("llvm", host="llvm"),
                 "global_symbol": "main",
             }
         )
@@ -90,7 +90,9 @@ def test_variable_passed_from_args():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, 
not_device_context], stmt))
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", 
tvm.target.Target("llvm")))(mod)
+    mod = tvm.tir.transform.Apply(
+        lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
+    )(mod)
     mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", 
"main"))(mod)
     func = tvm.tir.transform.MakePackedAPI()(mod)["main"]
 
@@ -132,7 +134,9 @@ def test_device_api_context_implicit_resource_handle():
     stmt = ib.get()
 
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([input_buffer, 
device_context], stmt))
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", 
tvm.target.Target("llvm")))(mod)
+    mod = tvm.tir.transform.Apply(
+        lambda f: f.with_attr("target", tvm.target.Target("llvm", host="llvm"))
+    )(mod)
     mod = tvm.tir.transform.Apply(lambda f: f.with_attr("global_symbol", 
"main"))(mod)
     func = tvm.tir.transform.MakePackedAPI()(mod)["main"]
 
@@ -161,7 +165,7 @@ def test_device_api_context_implicit_resource_handle():
 
 @pytest.mark.parametrize("use_global_symbol", [True, False])
 def test_no_op_when_global_symbol_is_absent(use_global_symbol):
-    func_attr = {"target": tvm.target.Target("llvm")}
+    func_attr = {"target": tvm.target.Target("llvm", host="llvm")}
     if use_global_symbol:
         func_attr["global_symbol"] = "main"
 
@@ -177,6 +181,28 @@ def test_no_op_when_global_symbol_is_absent(use_global_symbol):
         tvm.ir.assert_structural_equal(before, after)
 
 
+def test_target_host_removed():
+    """After MakePackedAPI, host-side target should be the host
+
+    MakePackedAPI is the last transform that requires both the device
+    and the host.  After MakePackedAPI, the target attribute should
+    only contain the host-side target.
+    """
+
+    host = tvm.target.Target("llvm")
+
+    @I.ir_module
+    class before:
+        @T.prim_func
+        def main(A: T.Buffer(1, "float32")):
+            T.func_attr({"global_symbol": "main", "target": T.target("cuda", 
host=host)})
+            T.evaluate(0)
+
+    after = tvm.tir.transform.MakePackedAPI()(before)
+    target_attr = after["main"].attrs["target"]
+    assert str(host) == str(target_attr)
+
+
 def test_internal_subroutine_call():
     """Internal subroutines should not use the PackedFunc API
 
@@ -190,7 +216,7 @@ def test_internal_subroutine_call():
     class before:
         @T.prim_func
         def main(A: T.Buffer(1, "float32")):
-            T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
+            T.func_attr({"global_symbol": "main", "target": T.target("llvm", 
host="llvm")})
             before.subroutine(A.data)
 
         @T.prim_func
@@ -222,12 +248,12 @@ def test_subroutine_call_to_externally_visible_subroutine():
     class before:
         @T.prim_func
         def main(A: T.Buffer(1, "float32")):
-            T.func_attr({"global_symbol": "main", "target": T.target("llvm")})
+            T.func_attr({"global_symbol": "main", "target": T.target("llvm", 
host="llvm")})
             before.subroutine(A.data)
 
         @T.prim_func
         def subroutine(A_data: T.handle("float32")):
-            T.func_attr({"global_symbol": "subroutine", "target": 
T.target("llvm")})
+            T.func_attr({"global_symbol": "subroutine", "target": 
T.target("llvm", host="llvm")})
             T.evaluate(A_data)
 
     after = tvm.tir.transform.MakePackedAPI()(before)
diff --git a/tests/python/unittest/test_tir_transform_make_unpacked_api.py b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
index bb9fe8ab82..1931f7aef3 100644
--- a/tests/python/unittest/test_tir_transform_make_unpacked_api.py
+++ b/tests/python/unittest/test_tir_transform_make_unpacked_api.py
@@ -41,9 +41,8 @@ def mod(mod_without_attrs):
 
 
 def test_noop_if_not_global_symbol(mod_without_attrs):
-    before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", 
tvm.target.Target("llvm")))(
-        mod_without_attrs
-    )
+    target = tvm.target.Target("llvm", host="llvm")
+    before = tvm.tir.transform.Apply(lambda f: f.with_attr("target", 
target))(mod_without_attrs)
     after = tvm.tir.transform.MakeUnpackedAPI()(before)
     tvm.ir.assert_structural_equal(before, after)
 
@@ -59,7 +58,8 @@ def test_fails_if_no_target(mod_without_attrs):
 
 @tvm.testing.parametrize_targets("c", "llvm", "cuda")
 def test_device_setup(mod, target, dev):
-    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", 
tvm.target.Target(target)))(mod)
+    target = tvm.target.Target(target, host="llvm")
+    mod = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod)
     f = tvm.tir.transform.MakeUnpackedAPI()(mod)["main"]
     assert len(f.params) == 1
     assert f.params[0].name == "A"
@@ -138,6 +138,49 @@ def test_body():
     assert f.params[2].name == "A"
 
 
+class TestTargetHostRemoved(tvm.testing.CompareBeforeAfter):
+    """After MakeUnpackedAPI, host-side target should be the host
+
+    MakeUnpackedAPI is the last transform that requires both the device
+    and the host.  After MakeUnpackedAPI, the target attribute should
+    only contain the host-side target.
+    """
+
+    transform = tvm.tir.transform.MakeUnpackedAPI()
+
+    def before(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A: T.Buffer(1, "float32")):
+                T.func_attr({"global_symbol": "main", "target": 
T.target("cuda", host="llvm")})
+                mod.subroutine(A.data)
+
+            @T.prim_func
+            def subroutine(A_data: T.handle("float32")):
+                T.func_attr({"target": T.target("cuda")})
+                T.evaluate(A_data)
+
+        return mod
+
+    def expected(self):
+        @I.ir_module
+        class mod:
+            @T.prim_func
+            def main(A_data: T.handle("float32")) -> T.int32:
+                T.func_attr({"global_symbol": "main", "target": 
T.target("llvm")})
+                T.attr("default", "device_id", 0)
+                T.attr("default", "device_type", 2)
+                mod.subroutine(A_data)
+
+            @T.prim_func
+            def subroutine(A_data: T.handle("float32")):
+                T.func_attr({"target": T.target("cuda")})
+                T.evaluate(A_data)
+
+        return mod
+
+
 class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
     """Internal subroutines do not require modification
 
@@ -153,7 +196,7 @@ class TestInternalSubroutineCall(tvm.testing.CompareBeforeAfter):
         class mod:
             @T.prim_func
             def main(A: T.Buffer(1, "float32")):
-                T.func_attr({"global_symbol": "main", "target": 
T.target("llvm")})
+                T.func_attr({"global_symbol": "main", "target": 
T.target("llvm", host="llvm")})
                 mod.subroutine(A.data)
 
             @T.prim_func
@@ -195,12 +238,14 @@ class TestSubroutineCallToExternallyVisibleSubroutine(tvm.testing.CompareBeforeA
         class mod:
             @T.prim_func
             def main(A: T.Buffer(1, "float32")):
-                T.func_attr({"global_symbol": "main", "target": 
T.target("llvm")})
+                T.func_attr({"global_symbol": "main", "target": 
T.target("llvm", host="llvm")})
                 mod.subroutine(A.data)
 
             @T.prim_func
             def subroutine(A_data: T.handle("float32")):
-                T.func_attr({"global_symbol": "subroutine", "target": 
T.target("llvm")})
+                T.func_attr(
+                    {"global_symbol": "subroutine", "target": T.target("llvm", 
host="llvm")}
+                )
                 T.evaluate(A_data)
 
         return mod
@@ -240,7 +285,7 @@ class TestCallExternallyVisibleSubroutineWithDLTensor(tvm.testing.CompareBeforeA
         class mod:
             @T.prim_func
             def main(A: T.Buffer(1, "float32")):
-                T.func_attr({"global_symbol": "main", "target": 
T.target("llvm")})
+                T.func_attr({"global_symbol": "main", "target": 
T.target("llvm", host="llvm")})
                 mod.subroutine(
                     T.tvm_stack_make_array(
                         A.data,
@@ -255,7 +300,9 @@ class TestCallExternallyVisibleSubroutineWithDLTensor(tvm.testing.CompareBeforeA
 
             @T.prim_func
             def subroutine(A: T.Buffer(1, "float32")):
-                T.func_attr({"global_symbol": "subroutine", "target": 
T.target("llvm")})
+                T.func_attr(
+                    {"global_symbol": "subroutine", "target": T.target("llvm", 
host="llvm")}
+                )
                 T.evaluate(A.data)
 
         return mod
diff --git a/tests/python/unittest/test_tir_transform_split_host_device.py b/tests/python/unittest/test_tir_transform_split_host_device.py
index 1599b9a031..60bfb8a718 100644
--- a/tests/python/unittest/test_tir_transform_split_host_device.py
+++ b/tests/python/unittest/test_tir_transform_split_host_device.py
@@ -46,6 +46,7 @@ def test_split_host_device_func_attr():
         [
             tvm.tir.transform.AnnotateDeviceRegions(),
             tvm.tir.transform.SplitHostDevice(),
+            tvm.tir.transform.MakePackedAPI(),
             tvm.tir.transform.LowerDeviceKernelLaunch(),
         ]
     )(mod)
@@ -111,7 +112,7 @@ class TestSplitHostDevice(BaseCompare):
         class mod:
             @T.prim_func
             def main(n: T.int32):
-                T.func_attr({"target": T.target("llvm -opt-level=0")})
+                T.func_attr({"target": T.target("cuda", host="llvm 
-opt-level=0")})
                 mod.main_kernel(n)
 
             @T.prim_func
@@ -168,20 +169,19 @@ class TestSplitHostDeviceWithoutFuncHostAttribute(BaseCompare):
         return mod
 
 
-class TestSplitHostDevice(BaseCompare):
+class TestSplitHostDeviceWithoutDeviceRegion(BaseCompare):
     """Like TestSplitHostDevice, but no device regions to extract
 
-    Even if there are no device regions, the host-side function should
-    still have its "target" attribute updated.
+    Because MakePackedAPI/MakeUnpackedAPI still require both the
+    device and host, SplitHostDevice does not modify the "target"
+    attribute.
     """
 
     def before():
         T.func_attr({"target": T.target("ext_dev", host="llvm")})
         T.evaluate(0)
 
-    def expected():
-        T.func_attr({"target": T.target("llvm")})
-        T.evaluate(0)
+    expected = before
 
 
 if __name__ == "__main__":

Reply via email to