This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new b5928e775c [Unity] Fix CUDA graph rewrite var used before def (#14800)
b5928e775c is described below

commit b5928e775c308c9e42e2951ed5de726eb85d7ed1
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed May 10 03:31:12 2023 -0700

    [Unity] Fix CUDA graph rewrite var used before def (#14800)
    
    CUDA graph rewriting may result in a reordering of the original
    bindings, for example when a variable is used as an input to the
    lifted function. If the variable comes from the output of another
    function, we need to make sure the output unpacking is emitted
    before that use.
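
To make the ordering rule concrete, here is a minimal plain-Python sketch of the
idea. It is illustrative only: the names Rewriter, var_redef, var_remap and the
string-based bindings are hypothetical stand-ins, not TVM APIs; the real logic
is the var_redef_/var_remap_/EmitRedef code in CUDAGraphRewriter below.

    # Deferred redefinitions: record the tuple unpacking when a subgraph is
    # launched, and emit it either at the first use of the output variable or
    # at its original definition site, whichever the mutator visits first.
    class Rewriter:
        def __init__(self):
            self.var_redef = {}   # output var -> deferred expr, e.g. "ret[0]"
            self.var_remap = {}   # output var -> name of the emitted binding
            self.emitted = []     # bindings emitted so far, in order

        def launch_subgraph(self, outputs, ret):
            # Do not emit the unpacking yet; just remember it.
            for i, v in enumerate(outputs):
                self.var_redef[v] = f"{ret}[{i}]"

        def visit_use(self, var):
            # The var is used, e.g. as an input to another lifted function.
            if var in self.var_remap:
                return self.var_remap[var]
            if var in self.var_redef:
                # Used before its original definition is visited: emit now.
                return self._emit(var)
            return var

        def visit_def(self, var):
            # Original definition site: emit only if a use has not already
            # forced the redefinition (avoids a duplicate binding).
            if var in self.var_redef and var not in self.var_remap:
                self._emit(var)

        def _emit(self, var):
            new_var = f"{var}_redef"
            self.emitted.append(f"{new_var} = {self.var_redef[var]}")
            self.var_remap[var] = new_var
            return new_var

    r = Rewriter()
    r.launch_subgraph(["alloc2"], "ret")
    r.visit_use("alloc2")    # first use: emits "alloc2_redef = ret[0]"
    r.visit_def("alloc2")    # original def site: already emitted, no duplicate

Before this fix, the unpacking was only emitted when the original definition was
visited, so a use that the rewrite moved ahead of that point referred to a
variable that had not been defined yet.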
---
 src/relax/transform/rewrite_cuda_graph.cc          | 27 +++++++++--
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc    |  1 -
 .../relax/test_transform_rewrite_cuda_graph.py     | 53 ++++++++++++----------
 3 files changed, 54 insertions(+), 27 deletions(-)

diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index 9621d9ff58..42ec5fca9d 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -461,6 +461,9 @@ class CUDAGraphRewriter : public ExprMutator {
     }
     Expr ret_value = builder_->Emit(launch_subgraph);
     for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
+      // The unpacked result is saved in var_redef_. It will be emitted when 1) the var
+      // definition in the original IR is visited, or 2) the var is used as an input to
+      // another lifted function, whichever comes first.
       var_redef_[plan.outputs[i]] = TupleGetItem(ret_value, i);
     }
 
@@ -471,9 +474,9 @@ class CUDAGraphRewriter : public ExprMutator {
     if (subgraph_launches_.count(op->var.get())) {
       LaunchSubgraph(op, subgraph_launches_[op->var.get()]);
     }
-    if (auto it = var_redef_.find(op->var.get()); it != var_redef_.end()) {
-      auto new_var = builder_->Emit(it->second, op->var->name_hint());
-      var_remap_[op->var->vid] = new_var;
+    if (auto it = var_redef_.find(op->var.get());
+        it != var_redef_.end() && !var_remap_.count(op->var->vid)) {
+      EmitRedef(op->var.get(), it->second);
       return;
     }
     if (lifted_bindings_.count(op->var.get())) {
@@ -483,6 +486,24 @@ class CUDAGraphRewriter : public ExprMutator {
     ExprMutator::VisitBinding_(op);
   }
 
+  Expr VisitExpr_(const VarNode* op) final {
+    if (auto it = var_remap_.find(op->vid); it != var_remap_.end()) {
+      return it->second;
+    }
+    if (auto it = var_redef_.find(op); it != var_redef_.end()) {
+      // This is the case where the var is used as an input to another lifted function when
+      // the original var definition has not been visited yet.
+      return EmitRedef(op, it->second);
+    }
+    return GetRef<Expr>(op);
+  }
+
+  Var EmitRedef(const VarNode* var, const Expr& redef) {
+    auto new_var = builder_->Emit(redef, var->name_hint());
+    var_remap_[var->vid] = new_var;
+    return new_var;
+  }
+
   std::unordered_map<const VarNode*, LiftedFunctionRewritePlan> subgraph_launches_;
   std::unordered_map<const VarNode*, Expr> var_redef_;
   std::unordered_set<const VarNode*> lifted_bindings_;
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index 45342cf4ff..9d2025d647 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -87,7 +87,6 @@ class CUDAGraphCache : public Object {
   ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args,
                          int64_t entry_index) {
     if (auto it = capture_cache_.find(entry_index); it != capture_cache_.end()) {
-      LOG(INFO) << "HIT";
       // Launch CUDA graph
       const auto& [states, cuda_graph] = it->second;
       cudaGraphExec_t cuda_graph_exec;
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 4fc4d6f4a1..40c0a4a876 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -29,15 +29,11 @@ def test_rewrite_cuda_graph():
         def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
             # function attr dict
             T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
-            # body
-            # with T.block("root")
             for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                 for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                     with T.block("compute"):
                         i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4))
                         i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4))
-                        T.reads(rxplaceholder[i0, i1])
-                        T.writes(compute[i0, i1])
                         compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32")
 
 
@@ -54,12 +50,17 @@ def test_rewrite_cuda_graph():
             alloc2: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
             _4: R.Tuple = cls.exp(alloc1, alloc2)
             _5: R.Tuple = R.memory.kill_tensor(alloc1)
-            alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0)
-            _6 = cls.exp(alloc2, alloc3)
+            storage2: R.Object = R.memory.alloc_storage(R.shape([32]), 0, "global", "float32")
+            alloc3: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage2, 0, R.shape([2, 4]), "float32")
+            _6: R.Tuple = cls.exp(alloc2, alloc3)
             _7: R.Tuple = R.memory.kill_tensor(alloc2)
-            _8: R.Tuple = R.memory.kill_storage(storage)
-            _9: R.Tuple = R.memory.kill_storage(storage1)
-            return alloc3
+            alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0)
+            _8 = cls.exp(alloc3, alloc4)
+            _9: R.Tuple = R.memory.kill_tensor(alloc3)
+            _10: R.Tuple = R.memory.kill_storage(storage)
+            _11: R.Tuple = R.memory.kill_storage(storage1)
+            _12: R.Tuple = R.memory.kill_storage(storage2)
+            return alloc4
 
 
     @I.ir_module
@@ -80,40 +81,46 @@ def test_rewrite_cuda_graph():
                         compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32")
 
         @R.function
-        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
-            gv: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
-            gv1: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
-            gv2: R.Tuple(R.Object, R.Object) = (gv, gv1)
-            return gv2
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object):
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            storage2: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            gv: R.Tuple(R.Object, R.Object, R.Object) = (storage, storage1, storage2)
+            return gv
 
         @R.function
-        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
             cls = Expected
             _2: R.Tuple = cls.exp(alloc, alloc1)
             _3: R.Tuple = R.memory.kill_tensor(alloc)
             alloc2: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
             _4: R.Tuple = cls.exp(alloc1, alloc2)
             _5: R.Tuple = R.memory.kill_tensor(alloc1)
-            gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc2,)
+            alloc3: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage2, 0, R.shape([2, 4]), "float32")
+            _6: R.Tuple = cls.exp(alloc2, alloc3)
+            _7: R.Tuple = R.memory.kill_tensor(alloc2)
+            gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc3,)
             return gv
 
         @R.function
         def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
             cls = Expected
-            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),))
+            gv: R.Tuple(R.Object, R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object, R.Object),))
             storage: R.Object = gv[0]
             alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
             _1: R.Tuple = cls.exp(x, alloc)
             storage1: R.Object = gv[1]
             alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
-            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
-            alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
-            alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
-            _6: R.Tuple = cls.exp(alloc2, alloc3)
-            _7: R.Tuple = R.memory.kill_tensor(alloc2)
+            storage2: R.Object = gv[2]
+            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+            alloc3: R.Tensor((2, 4), dtype="float32") = gv1[0]
+            alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+            _6: R.Tuple = cls.exp(alloc3, alloc4)
+            _7: R.Tuple = R.memory.kill_tensor(alloc3)
             _8: R.Tuple = R.memory.kill_storage(storage)
             _9: R.Tuple = R.memory.kill_storage(storage1)
-            return alloc3
+            _10: R.Tuple = R.memory.kill_storage(storage2)
+            return alloc4
     # fmt: on
 
     after = relax.transform.RewriteCUDAGraph()(Before)
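
A test like the one in this diff is typically concluded by comparing the
rewritten module against the expected one. The exact tail of the test is not
shown in this hunk; the following is a hedged sketch of that step, where
Before and Expected refer to the modules defined above:

    import tvm
    from tvm import relax

    # Run the pass under test and check the result structurally.
    after = relax.transform.RewriteCUDAGraph()(Before)
    tvm.ir.assert_structural_equal(after, Expected)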
