This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
     new b5928e775c [Unity] Fix CUDA graph rewrite var used before def (#14800)
b5928e775c is described below

commit b5928e775c308c9e42e2951ed5de726eb85d7ed1
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed May 10 03:31:12 2023 -0700

    [Unity] Fix CUDA graph rewrite var used before def (#14800)

    CUDA graph rewriting may result in reordering of the original bindings,
    for example when a variable is used as an input of the lifted function.
    If the variable comes from the output of another function, we need to
    make sure output unpacking is emitted.
---
 src/relax/transform/rewrite_cuda_graph.cc       | 27 +++++++++--
 src/runtime/relax_vm/cuda/cuda_graph_builtin.cc |  1 -
 .../relax/test_transform_rewrite_cuda_graph.py  | 53 ++++++++++++----------
 3 files changed, 54 insertions(+), 27 deletions(-)

diff --git a/src/relax/transform/rewrite_cuda_graph.cc b/src/relax/transform/rewrite_cuda_graph.cc
index 9621d9ff58..42ec5fca9d 100644
--- a/src/relax/transform/rewrite_cuda_graph.cc
+++ b/src/relax/transform/rewrite_cuda_graph.cc
@@ -461,6 +461,9 @@ class CUDAGraphRewriter : public ExprMutator {
     }
     Expr ret_value = builder_->Emit(launch_subgraph);
     for (int i = 0; i < static_cast<int>(plan.outputs.size()); ++i) {
+      // The unpacked result is saved in var_redef_. It will be emitted when 1) the var
+      // definition in the original IR is visited, or 2) the var is used as an input to another
+      // lifted function, whichever comes first.
       var_redef_[plan.outputs[i]] = TupleGetItem(ret_value, i);
     }
 
@@ -471,9 +474,9 @@ class CUDAGraphRewriter : public ExprMutator {
     if (subgraph_launches_.count(op->var.get())) {
       LaunchSubgraph(op, subgraph_launches_[op->var.get()]);
     }
-    if (auto it = var_redef_.find(op->var.get()); it != var_redef_.end()) {
-      auto new_var = builder_->Emit(it->second, op->var->name_hint());
-      var_remap_[op->var->vid] = new_var;
+    if (auto it = var_redef_.find(op->var.get());
+        it != var_redef_.end() && !var_remap_.count(op->var->vid)) {
+      EmitRedef(op->var.get(), it->second);
       return;
     }
     if (lifted_bindings_.count(op->var.get())) {
@@ -483,6 +486,24 @@ class CUDAGraphRewriter : public ExprMutator {
     ExprMutator::VisitBinding_(op);
   }
 
+  Expr VisitExpr_(const VarNode* op) final {
+    if (auto it = var_remap_.find(op->vid); it != var_remap_.end()) {
+      return it->second;
+    }
+    if (auto it = var_redef_.find(op); it != var_redef_.end()) {
+      // This is the case that the var is used as an input to another lifted function when
+      // the original var definition is not visited yet.
+      return EmitRedef(op, it->second);
+    }
+    return GetRef<Expr>(op);
+  }
+
+  Var EmitRedef(const VarNode* var, const Expr& redef) {
+    auto new_var = builder_->Emit(redef, var->name_hint());
+    var_remap_[var->vid] = new_var;
+    return new_var;
+  }
+
   std::unordered_map<const VarNode*, LiftedFunctionRewritePlan> subgraph_launches_;
   std::unordered_map<const VarNode*, Expr> var_redef_;
   std::unordered_set<const VarNode*> lifted_bindings_;
diff --git a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
index 45342cf4ff..9d2025d647 100644
--- a/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
+++ b/src/runtime/relax_vm/cuda/cuda_graph_builtin.cc
@@ -87,7 +87,6 @@ class CUDAGraphCache : public Object {
   ObjectRef RunOrCapture(VirtualMachine* vm, const ObjectRef& capture_func, ObjectRef args,
                          int64_t entry_index) {
     if (auto it = capture_cache_.find(entry_index); it != capture_cache_.end()) {
-      LOG(INFO) << "HIT";
       // Launch CUDA graph
       const auto& [states, cuda_graph] = it->second;
       cudaGraphExec_t cuda_graph_exec;
diff --git a/tests/python/relax/test_transform_rewrite_cuda_graph.py b/tests/python/relax/test_transform_rewrite_cuda_graph.py
index 4fc4d6f4a1..40c0a4a876 100644
--- a/tests/python/relax/test_transform_rewrite_cuda_graph.py
+++ b/tests/python/relax/test_transform_rewrite_cuda_graph.py
@@ -29,15 +29,11 @@ def test_rewrite_cuda_graph():
         def exp(rxplaceholder: T.Buffer((T.int64(2), T.int64(4)), "float32"), compute: T.Buffer((T.int64(2), T.int64(4)), "float32")):
             # function attr dict
             T.func_attr({"tir.noalias": True, "global_symbol": "exp"})
-            # body
-            # with T.block("root")
            for i0_i1_fused_0 in T.thread_binding(T.int64(1), thread="blockIdx.x"):
                 for i0_i1_fused_1 in T.thread_binding(T.int64(8), thread="threadIdx.x"):
                     with T.block("compute"):
                         i0 = T.axis.spatial(T.int64(2), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) // T.int64(4))
                         i1 = T.axis.spatial(T.int64(4), (i0_i1_fused_0 * T.int64(8) + i0_i1_fused_1) % T.int64(4))
-                        T.reads(rxplaceholder[i0, i1])
-                        T.writes(compute[i0, i1])
                         compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32")
 
 
@@ -54,12 +50,17 @@ def test_rewrite_cuda_graph():
             alloc2: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, 0, R.shape([2, 4]), "float32")
             _4: R.Tuple = cls.exp(alloc1, alloc2)
             _5: R.Tuple = R.memory.kill_tensor(alloc1)
-            alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0)
-            _6 = cls.exp(alloc2, alloc3)
+            storage2: R.Object = R.memory.alloc_storage(R.shape([32]), 0, "global", "float32")
+            alloc3: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage2, 0, R.shape([2, 4]), "float32")
+            _6: R.Tuple = cls.exp(alloc2, alloc3)
             _7: R.Tuple = R.memory.kill_tensor(alloc2)
-            _8: R.Tuple = R.memory.kill_storage(storage)
-            _9: R.Tuple = R.memory.kill_storage(storage1)
-            return alloc3
+            alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), "float32", 0)
+            _8 = cls.exp(alloc3, alloc4)
+            _9: R.Tuple = R.memory.kill_tensor(alloc3)
+            _10: R.Tuple = R.memory.kill_storage(storage)
+            _11: R.Tuple = R.memory.kill_storage(storage1)
+            _12: R.Tuple = R.memory.kill_storage(storage2)
+            return alloc4
 
 
     @I.ir_module
@@ -80,40 +81,46 @@ def test_rewrite_cuda_graph():
                         compute[i0, i1] = T.exp(rxplaceholder[i0, i1], dtype="float32")
 
         @R.function
-        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object):
-            gv: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
-            gv1: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
-            gv2: R.Tuple(R.Object, R.Object) = (gv, gv1)
-            return gv2
+        def cuda_graph_alloc() -> R.Tuple(R.Object, R.Object, R.Object):
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            storage2: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            gv: R.Tuple(R.Object, R.Object, R.Object) = (storage, storage1, storage2)
+            return gv
 
         @R.function
-        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
+        def cuda_graph_capture(alloc: R.Tensor((2, 4), dtype="float32"), alloc1: R.Tensor((2, 4), dtype="float32"), storage: R.Object, storage2: R.Object) -> R.Tuple(R.Tensor((2, 4), dtype="float32")):
             cls = Expected
             _2: R.Tuple = cls.exp(alloc, alloc1)
             _3: R.Tuple = R.memory.kill_tensor(alloc)
             alloc2: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
             _4: R.Tuple = cls.exp(alloc1, alloc2)
             _5: R.Tuple = R.memory.kill_tensor(alloc1)
-            gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc2,)
+            alloc3: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage2, 0, R.shape([2, 4]), "float32")
+            _6: R.Tuple = cls.exp(alloc2, alloc3)
+            _7: R.Tuple = R.memory.kill_tensor(alloc2)
+            gv: R.Tuple(R.Tensor((2, 4), dtype="float32")) = (alloc3,)
             return gv
 
         @R.function
         def main(x: R.Tensor((2, 4), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
             cls = Expected
-            gv: R.Tuple(R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object),))
+            gv: R.Tuple(R.Object, R.Object, R.Object) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.get_cached_alloc", (cls.cuda_graph_alloc, R.prim_value(0)), sinfo_args=(R.Tuple(R.Object, R.Object, R.Object),))
             storage: R.Object = gv[0]
             alloc: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
             _1: R.Tuple = cls.exp(x, alloc)
             storage1: R.Object = gv[1]
             alloc1: R.Tensor((2, 4), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2, 4]), R.dtype("float32"))
-            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
-            alloc2: R.Tensor((2, 4), dtype="float32") = gv1[0]
-            alloc3: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
-            _6: R.Tuple = cls.exp(alloc2, alloc3)
-            _7: R.Tuple = R.memory.kill_tensor(alloc2)
+            storage2: R.Object = gv[2]
+            gv1: R.Tuple(R.Tensor((2, 4), dtype="float32")) = R.call_builtin_with_ctx("vm.builtin.cuda_graph.run_or_capture", (cls.cuda_graph_capture, (alloc, alloc1, storage, storage2), R.prim_value(0)), sinfo_args=(R.Tuple(R.Tensor((2, 4), dtype="float32")),))
+            alloc3: R.Tensor((2, 4), dtype="float32") = gv1[0]
+            alloc4: R.Tensor((2, 4), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, 4]), R.dtype("float32"), R.prim_value(0))
+            _6: R.Tuple = cls.exp(alloc3, alloc4)
+            _7: R.Tuple = R.memory.kill_tensor(alloc3)
             _8: R.Tuple = R.memory.kill_storage(storage)
             _9: R.Tuple = R.memory.kill_storage(storage1)
-            return alloc3
+            _10: R.Tuple = R.memory.kill_storage(storage2)
+            return alloc4
     # fmt: on
 
     after = relax.transform.RewriteCUDAGraph()(Before)
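For readers skimming the diff, the idea behind the rewrite_cuda_graph.cc change can be illustrated outside of TVM. The following is only a minimal standalone Python sketch, not the TVM API: the Rewriter class, the string-valued bindings, and the method names here are hypothetical stand-ins for CUDAGraphRewriter, var_redef_, and var_remap_. It shows the emit-once rule the commit introduces: a pending redefinition (the tuple unpacking of a lifted call's result) is emitted at whichever point comes first, the original definition site or an earlier use as an input to another lifted function, and never twice.

# Standalone sketch (hypothetical names, not TVM API) of the redef-emission rule.
class Rewriter:
    def __init__(self):
        self.var_redef = {}   # var -> pending redefinition expr, e.g. "ret[0]"
        self.var_remap = {}   # var -> name of the binding already emitted for it
        self.emitted = []     # bindings in final emission order

    def schedule_redef(self, var, expr):
        # Called when a lifted call is emitted: `var` must later be unpacked from its result.
        self.var_redef[var] = expr

    def _emit_redef(self, var):
        new_var = var  # keep the original name hint
        self.emitted.append((new_var, self.var_redef[var]))
        self.var_remap[var] = new_var
        return new_var

    def visit_use(self, var):
        # A use of `var` as an input, possibly before its original definition is visited.
        if var in self.var_remap:
            return self.var_remap[var]
        if var in self.var_redef:
            return self._emit_redef(var)
        return var

    def visit_definition(self, var, expr):
        # The original definition site; emit the redef here only if a use did not already.
        if var in self.var_redef:
            if var not in self.var_remap:
                self._emit_redef(var)
            return
        self.emitted.append((var, expr))


r = Rewriter()
r.schedule_redef("alloc3", "ret[0]")         # alloc3 now comes from the lifted call's result
print(r.visit_use("alloc3"))                 # unpacking emitted at the first use: 'alloc3'
r.visit_definition("alloc3", "exp(alloc2)")  # original definition is skipped, not duplicated
print(r.emitted)                             # [('alloc3', 'ret[0]')]

This mirrors the guard added in VisitBinding_ (skip the redef if var_remap_ already has an entry) together with the new VisitExpr_/EmitRedef path that emits the unpacking when the variable is first referenced.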