This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new 5ad8941252 [Unity] Add pass to allocate big workspace and pass it to all functions that need temp storage (#14802) 5ad8941252 is described below commit 5ad8941252fd0c3e7596863f594b0bb7b1bc5243 Author: masahi <masahi...@gmail.com> AuthorDate: Wed May 10 09:49:43 2023 +0900 [Unity] Add pass to allocate big workspace and pass it to all functions that need temp storage (#14802) * Add workspace allocation and rewriting pass for CUTLASS * fix when workspace is not needed * wip * rename to Allocateworkspace * minor * minor * fixed test * add test * add doc * black * zeros -> alloc_tensor for workspace --- include/tvm/relax/expr.h | 2 + python/tvm/contrib/cutlass/attention_operation.py | 9 +- python/tvm/contrib/cutlass/build.py | 19 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 9 +- python/tvm/relax/backend/contrib/cutlass.py | 56 +++++- python/tvm/relax/transform/transform.py | 15 ++ src/relax/backend/vm/codegen_vm.cc | 6 +- src/relax/ir/block_builder.cc | 2 +- src/relax/op/op_common.h | 3 + src/relax/transform/allocate_workspace.cc | 199 +++++++++++++++++++++ tests/python/relax/test_codegen_cutlass.py | 38 ++-- .../relax/test_transform_allocate_workspace.py | 132 ++++++++++++++ 12 files changed, 448 insertions(+), 42 deletions(-) diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h index 0788193ee7..f090610019 100644 --- a/include/tvm/relax/expr.h +++ b/include/tvm/relax/expr.h @@ -983,6 +983,8 @@ constexpr const char* kCodegen = "Codegen"; constexpr const char* kComposite = "Composite"; /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */ constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern"; +/*! \brief The required workspace for an external function. */ +constexpr const char* kWorkspaceSize = "WorkspaceSize"; } // namespace attr /*! \brief The extern function, which can represent packed function. */ diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index 57c9ef4f91..c728f7fe4b 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -98,10 +98,7 @@ def instantiate_attention_template(attrs): p.output_ptr = reinterpret_cast<T *>(out0->data); p.output_accum_ptr = nullptr; if (Attention::kNeedsOutputAccumulatorBuffer) { - cudaMalloc( - &p.output_accum_ptr, - ${output_size} * sizeof(Attention::output_accum_t) - ); + p.output_accum_ptr = static_cast<float*>(${workspace}->data); } p.num_heads = ${num_heads}; // N @@ -131,10 +128,6 @@ def instantiate_attention_template(attrs): CHECK(Attention::check_supported(p)); kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p); - - if (Attention::kNeedsOutputAccumulatorBuffer) { - cudaFree(p.output_accum_ptr); - } """ template = substitute_template( diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 389dbf3e5c..519754d407 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -791,31 +791,38 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator): arg["arg0_dtype"] = signature["arg0_dtype"] arg["arg1_shape"] = q_shape = signature["arg1_shape"] - if "arg2_shape" not in signature: + if "arg3_shape" not in signature: + # arg0: qkv, arg1: shape, arg2: workspace arg["arg2_shape"] = k_shape = signature["arg1_shape"] arg["arg3_shape"] = v_shape = signature["arg1_shape"] else: - assert "arg3_shape" in signature + # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: workspace arg["arg2_shape"] = k_shape = signature["arg2_shape"] arg["arg3_shape"] = v_shape = signature["arg3_shape"] - if "arg4_dtype" in signature: + if "arg5_dtype" in signature: + # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: bias, arg5: workspace arg["bias_dtype"] = signature["arg4_dtype"] - if "arg4_shape" in signature: + if "arg5_shape" in signature: arg["bias_shape"] = signature["arg4_shape"] + qkv_layout = "qkv_stacked" else: + # arg0: q, arg1: k, arg2: v, arg3: bias, arg4: workspace arg["arg0_shape"] = q_shape = signature["arg0_shape"] arg["arg1_shape"] = k_shape = signature["arg1_shape"] arg["arg2_shape"] = v_shape = signature["arg2_shape"] arg["arg0_dtype"] = signature["arg0_dtype"] arg["arg1_dtype"] = signature["arg1_dtype"] arg["arg2_dtype"] = signature["arg2_dtype"] - if "arg3_dtype" in signature: + + if "arg4_dtype" in signature: arg["bias_dtype"] = signature["arg3_dtype"] - if "arg3_shape" in signature: + if "arg4_shape" in signature: arg["bias_shape"] = signature["arg3_shape"] + qkv_layout = "default" + out_shape = signature["ret_shape"] out_dtype = signature["ret_dtype"] num_batches, num_queries, num_heads, head_dim = q_shape diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 5e5ac621ef..f94d7ef467 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -727,15 +727,20 @@ def instantiate_template(func_name, annotations, func_args): attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"]) attrs["kSupportsDropout"] = False attrs["qkv_layout"] = annotations["qkv_layout"] + + for arg in func_args: + if "workspace" in arg: + attrs["workspace"] = arg + if attrs["qkv_layout"] == "default": attrs["query"] = func_args[0] attrs["key"] = func_args[1] attrs["value"] = func_args[2] - if len(func_args) > 3: + if len(func_args) > 4: # +1 for workspace, the last arg attrs["bias"] = func_args[3] elif attrs["qkv_layout"] == "qkv_stacked": attrs["qkv"] = func_args[0] - if len(func_args) > 4: + if len(func_args) > 5: # +1 for workspace, the last arg attrs["bias"] = func_args[4] else: raise NotImplementedError() diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 36f43c6c21..d5940ac5e4 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -16,11 +16,13 @@ # under the License. """Pattern table for CUTLASS backend""" - +import operator from typing import Mapping, Sequence +from functools import reduce +import tvm from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul -from tvm.relax import DataflowVar, Var, transform, Call +from tvm.relax import DataflowVar, Var, transform, Call, PyExprMutator, expr_functor, Function from tvm.relax.transform import PatternCheckContext from tvm.relax.dpl import rewrite_call @@ -373,6 +375,46 @@ register_patterns( _REWRITE_PATTERNS = [*attention_rewrite_patterns()] +@expr_functor.mutator +class WorkspaceAnnotator(PyExprMutator): + """Annotate a workspace requirement for each CUTLASS-offloaded function.""" + + def __init__(self, mod): + super().__init__(mod) + + def visit_function_(self, f): + if f.attrs is None or "Composite" not in f.attrs: + body = super().visit_expr(f.body) + new_f = Function(f.params, body, f.ret_struct_info, f.attrs, f.span) + + if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]: + composite_func = body.blocks[0].bindings[0].value + if "WorkspaceSize" in composite_func.attrs: + return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"]) + + return new_f + + if "attention" in f.attrs["Composite"]: + # Workspace is needed only for larger head sizes, but for simplicity we always allocate. + out_dtype = f.ret_struct_info.dtype + out_size_1d = reduce(operator.mul, f.ret_struct_info.shape, 1) + # This needs to be in sync with the actual value that the kernel expects. + workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype] + return f.with_attr("WorkspaceSize", workspace_size_bytes) + + return f + + +@tvm.transform.module_pass(opt_level=0) +def annotate_workspace(mod, _): + """Pass to annotate a workspace requirement for each CUTLASS-offloaded function.""" + annotator = WorkspaceAnnotator(mod) + for name, f in mod.functions.items(): + new_f = annotator.visit_expr(f) + mod.update_func(name, new_f) + return mod + + def partition_for_cutlass(mod, annotate_codegen=True): """ Partition the input module into CUTLASS-supported subgraphs. @@ -396,6 +438,12 @@ def partition_for_cutlass(mod, annotate_codegen=True): for pattern, rewriter in _REWRITE_PATTERNS: mod["main"] = rewrite_call(pattern, rewriter, mod["main"]) patterns = get_patterns_with_prefix("cutlass") - return transform.FuseOpsByPattern( - patterns, bind_constants=False, annotate_codegen=annotate_codegen + return tvm.transform.Sequential( + [ + transform.FuseOpsByPattern( + patterns, bind_constants=False, annotate_codegen=annotate_codegen + ), + annotate_workspace, + transform.AllocateWorkspace(), + ] )(mod) diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py index b0d2710a99..508e8bccba 100644 --- a/python/tvm/relax/transform/transform.py +++ b/python/tvm/relax/transform/transform.py @@ -1016,6 +1016,21 @@ def RewriteCUDAGraph() -> tvm.ir.transform.Pass: return _ffi_api.RewriteCUDAGraph() # type: ignore +def AllocateWorkspace() -> tvm.ir.transform.Pass: + """Allocate a workspace, represented by a tensor of size big enough for all external + functions that require a temporary storage, and append it to the arguments of external + functions. + + An external function can specify its workspace requirement by the kWorkspaceSize attribute. + + Returns + ------- + ret: tvm.ir.transform.Pass + The registered pass for allocating workspace. + """ + return _ffi_api.AllocateWorkspace() # type: ignore + + def _wrap_class_function_pass(pass_cls, pass_info): """Wrap a python class as function pass.""" diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc index 09f21cf751..c44300907f 100644 --- a/src/relax/backend/vm/codegen_vm.cc +++ b/src/relax/backend/vm/codegen_vm.cc @@ -70,11 +70,11 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const Expr&)> { IRModule res_mod = IRModule(Map<GlobalVar, BaseFunc>()); CodeGenVM codegen(builder, mod); // Remove relax function and turn into TIR func. - for (auto& p : mod->functions) { - if (auto* func = p.second.as<FunctionNode>()) { + for (const auto& [gvar, f] : mod->functions) { + if (auto* func = f.as<FunctionNode>()) { codegen.Codegen(GetRef<Function>(func)); } else { - res_mod->Add(p.first, p.second); + res_mod->Add(gvar, f); } } return res_mod; diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc index fe9e9bf8a5..5f9ce63c97 100644 --- a/src/relax/ir/block_builder.cc +++ b/src/relax/ir/block_builder.cc @@ -82,7 +82,7 @@ class BlockBuilderImpl : public BlockBuilderNode { while (context_mod_->ContainGlobalVar(func_name)) { func_name = GetUniqueName(func_name_hint); } - GlobalVar gvar = GlobalVar(func_name); + GlobalVar gvar(func_name); StructInfo finfo; if (func->struct_info_.defined()) { diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h index 8f5d1fbaa1..f7cff638cd 100644 --- a/src/relax/op/op_common.h +++ b/src/relax/op/op_common.h @@ -346,6 +346,9 @@ inline Optional<ShapeExpr> CheckNdimPerLayoutAndGetShape(const Call& call, const return NullOpt; } +Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm dtype); +Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype); + } // namespace relax } // namespace tvm diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc new file mode 100644 index 0000000000..b20f982efb --- /dev/null +++ b/src/relax/transform/allocate_workspace.cc @@ -0,0 +1,199 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file src/relax/transform/allocate_workspace.cc + * \brief Allocate a workspace and append it to the arguments of external functions, to + * satisfy their temporary storage requirement. + */ + +#include <tvm/ir/name_supply.h> +#include <tvm/relax/expr.h> +#include <tvm/relax/expr_functor.h> + +#include "../op/op_common.h" + +namespace tvm { +namespace relax { + +class ExternFunctionRewriter : ExprMutator { + public: + using ExprMutator::VisitExpr_; + + ExternFunctionRewriter(IRModule mod, size_t max_workspace_size) + : ExprMutator(mod), name_sup_(""), max_workspace_size_(max_workspace_size) {} + + std::unordered_map<const GlobalVarNode*, Function> Run() { + std::unordered_map<const GlobalVarNode*, Function> ret; + for (const auto& [gvar, f] : builder_->GetContextIRModule()->functions) { + if (f->GetAttr<Integer>(attr::kWorkspaceSize)) { + ret[gvar.get()] = Downcast<Function>(VisitExpr(f)); + } + } + return ret; + } + + Expr VisitExpr_(const FunctionNode* func_node) override { + if (!func_node->GetAttr<String>(attr::kCodegen) && + !func_node->GetAttr<String>(attr::kComposite)) { + return ExprMutator::VisitExpr_(func_node); + } + if (auto workspace = func_node->GetAttr<Integer>(attr::kWorkspaceSize)) { + // Append the workspace parameter to this function. + Array<Var> new_params = func_node->params; + + auto sinfo = TensorStructInfo(ShapeExpr({Integer(max_workspace_size_)}), DataType::UInt(8)); + Var workspace_param(name_sup_->FreshName("workspace"), sinfo); + + if (func_node->GetAttr<String>(attr::kCodegen)) { + workspace_var_param_ = workspace_param; + } + + new_params.push_back(workspace_param); + return Function(new_params, VisitExpr(func_node->body), func_node->ret_struct_info, + func_node->attrs); + } + return ExprMutator::VisitExpr_(func_node); + } + + Expr VisitExpr_(const CallNode* call_node) override { + auto new_op = VisitExpr(call_node->op); + if (auto var = new_op.as<Var>()) { + if (auto callee = builder_->LookupBinding(var.value()); + callee && callee->IsInstance<FunctionNode>() && + Downcast<Function>(callee.value())->GetAttr<String>(attr::kComposite)) { + // Append the workspace argument to this call. The callee should have been updated to accept + // a workspace as the last parameter. + auto new_args = call_node->args; + ICHECK(workspace_var_param_.defined()); + new_args.push_back(workspace_var_param_); + return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + } + } + return ExprMutator::VisitExpr_(call_node); + } + + private: + NameSupply name_sup_; + /*! \brief A variable that represents the workspace parameter passed from main. */ + Var workspace_var_param_; + size_t max_workspace_size_ = 0; +}; + +class WorkspaceProvider : ExprMutator { + public: + explicit WorkspaceProvider(IRModule mod) : ExprMutator(mod), mod_(mod) {} + using ExprMutator::VisitBindingBlock_; + using ExprMutator::VisitExpr_; + + IRModule Run() { + for (const auto& [gvar, f] : mod_->functions) { + if (auto workspace = f->GetAttr<Integer>(relax::attr::kWorkspaceSize)) { + max_workspace_size_ = std::max<size_t>(max_workspace_size_, workspace.value()->value); + } + } + + if (max_workspace_size_ == 0) { + return mod_; + } + + auto new_funcs = relax::ExternFunctionRewriter(mod_, max_workspace_size_).Run(); + + for (const auto& [gvar, f] : new_funcs) { + auto new_gvar = builder_->AddFunction(f, gvar->name_hint); + // This is only required since the well-formed check requires kGlobalSymbol to be the same + // as the actual name of the global variable. + builder_->UpdateFunction(new_gvar, + WithAttr(f, tvm::attr::kGlobalSymbol, new_gvar->name_hint)); + gvar_map_[gvar] = new_gvar; + builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar)); + } + + auto gvar = mod_->GetGlobalVar("main"); + auto func = Downcast<Function>(mod_->Lookup(gvar)); + auto new_func = + Function(func->params, VisitExpr(func->body), func->ret_struct_info, func->attrs); + builder_->UpdateFunction(gvar, new_func); + return builder_->GetContextIRModule(); + } + + BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final { + builder_->BeginDataflowBlock(); + if (!workspace_var_main_.defined()) { + auto shape = ShapeExpr({Integer(max_workspace_size_)}); + auto ty = DataTypeImm(DataType::UInt(8)); + auto storage = MakeVMAllocStorage(shape, PrimValue::Int64(0), ty); + auto workspace = MakeVMAllocTensor(storage, PrimValue::Int64(0), shape, ty); + workspace_var_main_ = builder_->Emit(workspace, "workspace_main"); + } + for (const auto& binding : block_node->bindings) { + this->VisitBinding(binding); + } + return builder_->EndBlock(); + } + + Expr VisitExpr_(const GlobalVarNode* gvar_node) override { + if (gvar_map_.count(gvar_node)) { + return gvar_map_[gvar_node]; + } + return ExprMutator::VisitExpr_(gvar_node); + } + + Expr VisitExpr_(const CallNode* call_node) override { + auto new_op = VisitExpr(call_node->op); + + if (auto gv = new_op.as<GlobalVar>()) { + auto callee = builder_->GetContextIRModule()->Lookup(gv.value()); + if (callee->HasNonzeroAttr(attr::kWorkspaceSize)) { + auto new_args = call_node->args; + ICHECK(workspace_var_main_.defined()); + new_args.push_back(workspace_var_main_); + return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args, call_node->span); + } + } + + return ExprMutator::VisitExpr_(call_node); + } + + private: + IRModule mod_; + /*! \brief A variable that represents the workspace created at the beginning of main. */ + Var workspace_var_main_; + size_t max_workspace_size_ = 0; + /*! \brief A map from old global variables representing a function with workspace requirement to + * the new ones that are transformed to take an additional workspace parameter. This is only + * needed since the struct info of the global variables changes between transformation. */ + std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_; +}; + +} // namespace relax + +namespace transform { + +Pass AllocateWorkspace() { + runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = + [=](IRModule m, PassContext pc) { return relax::WorkspaceProvider(m).Run(); }; + + return CreateModulePass(pass_func, 0, "AllocateWorkspace", {}); +} + +TVM_REGISTER_GLOBAL("relax.transform.AllocateWorkspace").set_body_typed(AllocateWorkspace); + +} // namespace transform +} // namespace tvm diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 45d66b3704..7a831c094e 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -83,9 +83,9 @@ cutlass_enabled = pytest.mark.skipif( pytestmark = [cutlass_enabled] -def build_and_run(mod, inputs_np, target, legalize=False): +def build_and_run(mod, inputs_np, target, legalize=True): if legalize: - mod = relax.transform.LegalizeOps()(mod) + mod = relax.transform.LegalizeOps()(mod) # For cpu reference, nop for cutlass. dev = tvm.device(target, 0) ex = relax.build(mod, target) @@ -95,11 +95,13 @@ def build_and_run(mod, inputs_np, target, legalize=False): return f(*inputs).numpy() -def get_result_with_relax_cutlass_offload(mod, *args, assert_all_bindings_fused=True): +def get_result_with_relax_cutlass_offload( + mod, *args, assert_all_bindings_fused=True, num_final_bindings=1 +): mod = partition_for_cutlass(mod) if assert_all_bindings_fused: - assert len(mod["main"].body.blocks[0].bindings) == 1 + assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}}) mod = codegen_pass(mod) @@ -116,7 +118,7 @@ def test_kernel_sharing(): out = get_result_with_relax_cutlass_offload( Conv2dx2, data_np, weight1_np, weight2_np, assert_all_bindings_fused=False ) - ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm", legalize=True) + ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm") np.testing.assert_equal(out, ref) @@ -243,7 +245,7 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, epilogue, residual_bloc ) out = get_result_with_relax_cutlass_offload(mod, *args) - ref = build_and_run(mod, args, "llvm", legalize=True) + ref = build_and_run(mod, args, "llvm") tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5) @@ -369,7 +371,7 @@ def test_matmul_offload( residual_activation=residual_activation, ) out = get_result_with_relax_cutlass_offload(mod, *args) - ref = build_and_run(mod, args, "llvm", legalize=True) + ref = build_and_run(mod, args, "llvm") tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -616,7 +618,7 @@ def test_attention_offload(attention_size, attention_dtype): ) mod = get_relax_attention_module(q, k, v) - out = get_result_with_relax_cutlass_offload(mod, q, k, v) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -645,7 +647,7 @@ def test_attention_bias_offload(attention_bias_size): ) mod = get_relax_attention_module(q, k, v, bias) - out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -674,9 +676,9 @@ def test_attention_scale_offload(attention_scale_size, attention_scale): mod = get_relax_attention_module(q, k, v, bias, attention_scale) if bias is None: - out = get_result_with_relax_cutlass_offload(mod, q, k, v) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3) else: - out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias) + out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -777,9 +779,9 @@ def test_stacked_attention_split_offload(stacked_attention_size): ) if bias is None: - out = get_result_with_relax_cutlass_offload(mod, qkv) + out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3) else: - out = get_result_with_relax_cutlass_offload(mod, qkv, bias) + out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -795,9 +797,9 @@ def test_stacked_attention_strided_slice_offload(stacked_attention_size): qkv, b, s, n, h, h_v, "strided_slice", bias, scale, single_shape=single_shape ) if bias is None: - out = get_result_with_relax_cutlass_offload(mod, qkv) + out = get_result_with_relax_cutlass_offload(mod, qkv, num_final_bindings=3) else: - out = get_result_with_relax_cutlass_offload(mod, qkv, bias) + out = get_result_with_relax_cutlass_offload(mod, qkv, bias, num_final_bindings=3) tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) @@ -966,8 +968,8 @@ def test_attention_rewrite_offload(attention_rewrite_size): expected_out = build_and_run(expected_mod, [q, k, v], "cuda") tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, atol=1e-5) else: - original_out = build_and_run(original_mod, [q, k, v, bias], "cuda") - expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda") + original_out = build_and_run(original_mod, [q, k, v, bias], "cuda", legalize=False) + expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda", legalize=False) tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, atol=1e-5) @@ -1043,7 +1045,7 @@ def test_layer_norm(data_shape, dtype, axes): gamma = np.random.randn(data_shape[-1]).astype(dtype) beta = np.random.randn(data_shape[-1]).astype(dtype) out = build_and_run(mod, [inp, gamma, beta], "cuda") - ref = build_and_run(Module, [inp, gamma, beta], "llvm", legalize=True) + ref = build_and_run(Module, [inp, gamma, beta], "llvm") tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py new file mode 100644 index 0000000000..7ffbd01b05 --- /dev/null +++ b/tests/python/relax/test_transform_allocate_workspace.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +import tvm.testing +from tvm import relax +from tvm.script import ir as I +from tvm.script import relax as R + + +@I.ir_module +class Module: + @R.function + def fused_relax_nn_attention_cutlass( + q: R.Tensor((32, 8, 16, 8), dtype="float16"), + k: R.Tensor((32, 8, 16, 8), dtype="float16"), + v: R.Tensor((32, 8, 16, 8), dtype="float16"), + ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): + R.func_attr( + { + "Codegen": "cutlass", + "WorkspaceSize": 65536, + "global_symbol": "fused_relax_nn_attention_cutlass", + } + ) + + @R.function + def gv( + q_1: R.Tensor((32, 8, 16, 8), dtype="float16"), + k_1: R.Tensor((32, 8, 16, 8), dtype="float16"), + v_1: R.Tensor((32, 8, 16, 8), dtype="float16"), + ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): + R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, "WorkspaceSize": 65536}) + with R.dataflow(): + gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") = R.nn.attention( + q_1, k_1, v_1, scale=None + ) + R.output(gv_2) + return gv_2 + + gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v) + return gv1 + + @R.function + def main( + q: R.Tensor((32, 8, 16, 8), dtype="float16"), + k: R.Tensor((32, 8, 16, 8), dtype="float16"), + v: R.Tensor((32, 8, 16, 8), dtype="float16"), + ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): + cls = Module + with R.dataflow(): + gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass( + q, k, v + ) + R.output(gv) + return gv + + +@I.ir_module +class Expected: + @R.function + def fused_relax_nn_attention_cutlass1( + q: R.Tensor((32, 8, 16, 8), dtype="float16"), + k: R.Tensor((32, 8, 16, 8), dtype="float16"), + v: R.Tensor((32, 8, 16, 8), dtype="float16"), + workspace: R.Tensor((65536,), dtype="uint8"), + ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): + R.func_attr( + { + "Codegen": "cutlass", + "WorkspaceSize": 65536, + "global_symbol": "fused_relax_nn_attention_cutlass1", + } + ) + + @R.function + def gv( + q_1: R.Tensor((32, 8, 16, 8), dtype="float16"), + k_1: R.Tensor((32, 8, 16, 8), dtype="float16"), + v_1: R.Tensor((32, 8, 16, 8), dtype="float16"), + workspace_1: R.Tensor((65536,), dtype="uint8"), + ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): + R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, "WorkspaceSize": 65536}) + with R.dataflow(): + gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") = R.nn.attention( + q_1, k_1, v_1, scale=None + ) + R.output(gv_2) + return gv_2 + + gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v, workspace) + return gv1 + + @R.function + def main( + q: R.Tensor((32, 8, 16, 8), dtype="float16"), + k: R.Tensor((32, 8, 16, 8), dtype="float16"), + v: R.Tensor((32, 8, 16, 8), dtype="float16"), + ) -> R.Tensor((32, 8, 16, 8), dtype="float16"): + cls = Expected + with R.dataflow(): + lv: R.Object = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8")) + workspace_main: R.Tensor((65536,), dtype="uint8") = R.vm.alloc_tensor( + lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8") + ) + gv: R.Tensor((32, 8, 16, 8), dtype="float16") = cls.fused_relax_nn_attention_cutlass1( + q, k, v, workspace_main + ) + R.output(gv) + return gv + + +def test_single_attention(): + rewritten = relax.transform.AllocateWorkspace()(Module) + tvm.ir.assert_structural_equal(rewritten, Expected) + + +if __name__ == "__main__": + tvm.testing.main()