This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 5ad8941252 [Unity] Add pass to allocate big workspace and pass it to all functions that need temp storage   (#14802)
5ad8941252 is described below

commit 5ad8941252fd0c3e7596863f594b0bb7b1bc5243
Author: masahi <masahi...@gmail.com>
AuthorDate: Wed May 10 09:49:43 2023 +0900

    [Unity] Add pass to allocate big workspace and pass it to all functions that need temp storage   (#14802)
    
    * Add workspace allocation and rewriting pass for CUTLASS
    
    * fix when workspace is not needed
    
    * wip
    
    * rename to Allocateworkspace
    
    * minor
    
    * minor
    
    * fixed test
    
    * add test
    
    * add doc
    
    * black
    
    * zeros -> alloc_tensor for workspace
---
 include/tvm/relax/expr.h                           |   2 +
 python/tvm/contrib/cutlass/attention_operation.py  |   9 +-
 python/tvm/contrib/cutlass/build.py                |  19 +-
 python/tvm/contrib/cutlass/gen_tensor_op.py        |   9 +-
 python/tvm/relax/backend/contrib/cutlass.py        |  56 +++++-
 python/tvm/relax/transform/transform.py            |  15 ++
 src/relax/backend/vm/codegen_vm.cc                 |   6 +-
 src/relax/ir/block_builder.cc                      |   2 +-
 src/relax/op/op_common.h                           |   3 +
 src/relax/transform/allocate_workspace.cc          | 199 +++++++++++++++++++++
 tests/python/relax/test_codegen_cutlass.py         |  38 ++--
 .../relax/test_transform_allocate_workspace.py     | 132 ++++++++++++++
 12 files changed, 448 insertions(+), 42 deletions(-)

diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 0788193ee7..f090610019 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -983,6 +983,8 @@ constexpr const char* kCodegen = "Codegen";
 constexpr const char* kComposite = "Composite";
 /*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
 constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
+/*! \brief The required workspace for an external function. */
+constexpr const char* kWorkspaceSize = "WorkspaceSize";
 }  // namespace attr
 
 /*! \brief The extern function, which can represent packed function. */
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 57c9ef4f91..c728f7fe4b 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -98,10 +98,7 @@ def instantiate_attention_template(attrs):
   p.output_ptr = reinterpret_cast<T *>(out0->data);
   p.output_accum_ptr = nullptr;
   if (Attention::kNeedsOutputAccumulatorBuffer) {
-    cudaMalloc(
-      &p.output_accum_ptr,
-      ${output_size} * sizeof(Attention::output_accum_t)
-    );
+    p.output_accum_ptr = static_cast<float*>(${workspace}->data);
   }
 
   p.num_heads = ${num_heads}; // N
@@ -131,10 +128,6 @@ def instantiate_attention_template(attrs):
 
   CHECK(Attention::check_supported(p));
   kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
-
-  if (Attention::kNeedsOutputAccumulatorBuffer) {
-    cudaFree(p.output_accum_ptr);
-  }
 """
 
     template = substitute_template(
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index 389dbf3e5c..519754d407 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -791,31 +791,38 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
             arg["arg0_dtype"] = signature["arg0_dtype"]
             arg["arg1_shape"] = q_shape = signature["arg1_shape"]
 
-            if "arg2_shape" not in signature:
+            if "arg3_shape" not in signature:
+                # arg0: qkv, arg1: shape, arg2: workspace
                 arg["arg2_shape"] = k_shape = signature["arg1_shape"]
                 arg["arg3_shape"] = v_shape = signature["arg1_shape"]
             else:
-                assert "arg3_shape" in signature
+                # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: workspace
                 arg["arg2_shape"] = k_shape = signature["arg2_shape"]
                 arg["arg3_shape"] = v_shape = signature["arg3_shape"]
 
-            if "arg4_dtype" in signature:
+            if "arg5_dtype" in signature:
+                # arg0: qkv, arg1: shape1, arg2: shape2, arg3: shape3, arg4: bias, arg5: workspace
                 arg["bias_dtype"] = signature["arg4_dtype"]
-            if "arg4_shape" in signature:
+            if "arg5_shape" in signature:
                 arg["bias_shape"] = signature["arg4_shape"]
+
             qkv_layout = "qkv_stacked"
         else:
+            # arg0: q, arg1: k, arg2: v,  arg3: bias, arg4: workspace
             arg["arg0_shape"] = q_shape = signature["arg0_shape"]
             arg["arg1_shape"] = k_shape = signature["arg1_shape"]
             arg["arg2_shape"] = v_shape = signature["arg2_shape"]
             arg["arg0_dtype"] = signature["arg0_dtype"]
             arg["arg1_dtype"] = signature["arg1_dtype"]
             arg["arg2_dtype"] = signature["arg2_dtype"]
-            if "arg3_dtype" in signature:
+
+            if "arg4_dtype" in signature:
                 arg["bias_dtype"] = signature["arg3_dtype"]
-            if "arg3_shape" in signature:
+            if "arg4_shape" in signature:
                 arg["bias_shape"] = signature["arg3_shape"]
+
             qkv_layout = "default"
+
         out_shape = signature["ret_shape"]
         out_dtype = signature["ret_dtype"]
         num_batches, num_queries, num_heads, head_dim = q_shape
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 5e5ac621ef..f94d7ef467 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -727,15 +727,20 @@ def instantiate_template(func_name, annotations, func_args):
         attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
         attrs["kSupportsDropout"] = False
         attrs["qkv_layout"] = annotations["qkv_layout"]
+
+        for arg in func_args:
+            if "workspace" in arg:
+                attrs["workspace"] = arg
+
         if attrs["qkv_layout"] == "default":
             attrs["query"] = func_args[0]
             attrs["key"] = func_args[1]
             attrs["value"] = func_args[2]
-            if len(func_args) > 3:
+            if len(func_args) > 4:  # +1 for workspace, the last arg
                 attrs["bias"] = func_args[3]
         elif attrs["qkv_layout"] == "qkv_stacked":
             attrs["qkv"] = func_args[0]
-            if len(func_args) > 4:
+            if len(func_args) > 5:  # +1 for workspace, the last arg
                 attrs["bias"] = func_args[4]
         else:
             raise NotImplementedError()
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 36f43c6c21..d5940ac5e4 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -16,11 +16,13 @@
 # under the License.
 
 """Pattern table for CUTLASS backend"""
-
+import operator
 from typing import Mapping, Sequence
+from functools import reduce
 
+import tvm
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import DataflowVar, Var, transform, Call
+from tvm.relax import DataflowVar, Var, transform, Call, PyExprMutator, expr_functor, Function
 from tvm.relax.transform import PatternCheckContext
 from tvm.relax.dpl import rewrite_call
 
@@ -373,6 +375,46 @@ register_patterns(
 _REWRITE_PATTERNS = [*attention_rewrite_patterns()]
 
 
+@expr_functor.mutator
+class WorkspaceAnnotator(PyExprMutator):
+    """Annotate a workspace requirement for each CUTLASS-offloaded function."""
+
+    def __init__(self, mod):
+        super().__init__(mod)
+
+    def visit_function_(self, f):
+        if f.attrs is None or "Composite" not in f.attrs:
+            body = super().visit_expr(f.body)
+            new_f = Function(f.params, body, f.ret_struct_info, f.attrs, f.span)
+
+            if f.attrs and "global_symbol" in f.attrs and "cutlass" in f.attrs["global_symbol"]:
+                composite_func = body.blocks[0].bindings[0].value
+                if "WorkspaceSize" in composite_func.attrs:
+                    return new_f.with_attr("WorkspaceSize", composite_func.attrs["WorkspaceSize"])
+
+            return new_f
+
+        if "attention" in f.attrs["Composite"]:
+            # Workspace is needed only for larger head sizes, but for simplicity we always allocate.
+            out_dtype = f.ret_struct_info.dtype
+            out_size_1d = reduce(operator.mul, f.ret_struct_info.shape, 1)
+            # This needs to be in sync with the actual value that the kernel expects.
+            workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
+            return f.with_attr("WorkspaceSize", workspace_size_bytes)
+
+        return f
+
+
+@tvm.transform.module_pass(opt_level=0)
+def annotate_workspace(mod, _):
+    """Pass to annotate a workspace requirement for each CUTLASS-offloaded function."""
+    annotator = WorkspaceAnnotator(mod)
+    for name, f in mod.functions.items():
+        new_f = annotator.visit_expr(f)
+        mod.update_func(name, new_f)
+    return mod
+
+
 def partition_for_cutlass(mod, annotate_codegen=True):
     """
     Partition the input module into CUTLASS-supported subgraphs.
@@ -396,6 +438,12 @@ def partition_for_cutlass(mod, annotate_codegen=True):
     for pattern, rewriter in _REWRITE_PATTERNS:
         mod["main"] = rewrite_call(pattern, rewriter, mod["main"])
     patterns = get_patterns_with_prefix("cutlass")
-    return transform.FuseOpsByPattern(
-        patterns, bind_constants=False, annotate_codegen=annotate_codegen
+    return tvm.transform.Sequential(
+        [
+            transform.FuseOpsByPattern(
+                patterns, bind_constants=False, annotate_codegen=annotate_codegen
+            ),
+            annotate_workspace,
+            transform.AllocateWorkspace(),
+        ]
     )(mod)
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index b0d2710a99..508e8bccba 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1016,6 +1016,21 @@ def RewriteCUDAGraph() -> tvm.ir.transform.Pass:
     return _ffi_api.RewriteCUDAGraph()  # type: ignore
 
 
+def AllocateWorkspace() -> tvm.ir.transform.Pass:
+    """Allocate a workspace, represented by a tensor of size big enough for all external
+    functions that require a temporary storage, and append it to the arguments of external
+    functions.
+
+    An external function can specify its workspace requirement by the kWorkspaceSize attribute.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+        The registered pass for allocating workspace.
+    """
+    return _ffi_api.AllocateWorkspace()  # type: ignore
+
+
 def _wrap_class_function_pass(pass_cls, pass_info):
     """Wrap a python class as function pass."""
 
diff --git a/src/relax/backend/vm/codegen_vm.cc b/src/relax/backend/vm/codegen_vm.cc
index 09f21cf751..c44300907f 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -70,11 +70,11 @@ class CodeGenVM : public ExprFunctor<Instruction::Arg(const 
Expr&)> {
     IRModule res_mod = IRModule(Map<GlobalVar, BaseFunc>());
     CodeGenVM codegen(builder, mod);
     // Remove relax function and turn into TIR func.
-    for (auto& p : mod->functions) {
-      if (auto* func = p.second.as<FunctionNode>()) {
+    for (const auto& [gvar, f] : mod->functions) {
+      if (auto* func = f.as<FunctionNode>()) {
         codegen.Codegen(GetRef<Function>(func));
       } else {
-        res_mod->Add(p.first, p.second);
+        res_mod->Add(gvar, f);
       }
     }
     return res_mod;
diff --git a/src/relax/ir/block_builder.cc b/src/relax/ir/block_builder.cc
index fe9e9bf8a5..5f9ce63c97 100644
--- a/src/relax/ir/block_builder.cc
+++ b/src/relax/ir/block_builder.cc
@@ -82,7 +82,7 @@ class BlockBuilderImpl : public BlockBuilderNode {
       while (context_mod_->ContainGlobalVar(func_name)) {
         func_name = GetUniqueName(func_name_hint);
       }
-      GlobalVar gvar = GlobalVar(func_name);
+      GlobalVar gvar(func_name);
 
       StructInfo finfo;
       if (func->struct_info_.defined()) {
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 8f5d1fbaa1..f7cff638cd 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -346,6 +346,9 @@ inline Optional<ShapeExpr> 
CheckNdimPerLayoutAndGetShape(const Call& call, const
   return NullOpt;
 }
 
+Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm dtype);
+Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/src/relax/transform/allocate_workspace.cc b/src/relax/transform/allocate_workspace.cc
new file mode 100644
index 0000000000..b20f982efb
--- /dev/null
+++ b/src/relax/transform/allocate_workspace.cc
@@ -0,0 +1,199 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ *
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/transform/allocate_workspace.cc
+ * \brief Allocate a workspace and append it to the arguments of external 
functions, to
+ * satisfy their temporary storage requirement.
+ */
+
+#include <tvm/ir/name_supply.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+
+#include "../op/op_common.h"
+
+namespace tvm {
+namespace relax {
+
+class ExternFunctionRewriter : ExprMutator {
+ public:
+  using ExprMutator::VisitExpr_;
+
+  ExternFunctionRewriter(IRModule mod, size_t max_workspace_size)
+      : ExprMutator(mod), name_sup_(""), 
max_workspace_size_(max_workspace_size) {}
+
+  std::unordered_map<const GlobalVarNode*, Function> Run() {
+    std::unordered_map<const GlobalVarNode*, Function> ret;
+    for (const auto& [gvar, f] : builder_->GetContextIRModule()->functions) {
+      if (f->GetAttr<Integer>(attr::kWorkspaceSize)) {
+        ret[gvar.get()] = Downcast<Function>(VisitExpr(f));
+      }
+    }
+    return ret;
+  }
+
+  Expr VisitExpr_(const FunctionNode* func_node) override {
+    if (!func_node->GetAttr<String>(attr::kCodegen) &&
+        !func_node->GetAttr<String>(attr::kComposite)) {
+      return ExprMutator::VisitExpr_(func_node);
+    }
+    if (auto workspace = func_node->GetAttr<Integer>(attr::kWorkspaceSize)) {
+      // Append the workspace parameter to this function.
+      Array<Var> new_params = func_node->params;
+
+      auto sinfo = TensorStructInfo(ShapeExpr({Integer(max_workspace_size_)}), 
DataType::UInt(8));
+      Var workspace_param(name_sup_->FreshName("workspace"), sinfo);
+
+      if (func_node->GetAttr<String>(attr::kCodegen)) {
+        workspace_var_param_ = workspace_param;
+      }
+
+      new_params.push_back(workspace_param);
+      return Function(new_params, VisitExpr(func_node->body), 
func_node->ret_struct_info,
+                      func_node->attrs);
+    }
+    return ExprMutator::VisitExpr_(func_node);
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) override {
+    auto new_op = VisitExpr(call_node->op);
+    if (auto var = new_op.as<Var>()) {
+      if (auto callee = builder_->LookupBinding(var.value());
+          callee && callee->IsInstance<FunctionNode>() &&
+          
Downcast<Function>(callee.value())->GetAttr<String>(attr::kComposite)) {
+        // Append the workspace argument to this call. The callee should have 
been updated to accept
+        // a workspace as the last parameter.
+        auto new_args = call_node->args;
+        ICHECK(workspace_var_param_.defined());
+        new_args.push_back(workspace_var_param_);
+        return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args, 
call_node->span);
+      }
+    }
+    return ExprMutator::VisitExpr_(call_node);
+  }
+
+ private:
+  NameSupply name_sup_;
+  /*! \brief A variable that represents the workspace parameter passed from 
main. */
+  Var workspace_var_param_;
+  size_t max_workspace_size_ = 0;
+};
+
+class WorkspaceProvider : ExprMutator {
+ public:
+  explicit WorkspaceProvider(IRModule mod) : ExprMutator(mod), mod_(mod) {}
+  using ExprMutator::VisitBindingBlock_;
+  using ExprMutator::VisitExpr_;
+
+  IRModule Run() {
+    for (const auto& [gvar, f] : mod_->functions) {
+      if (auto workspace = f->GetAttr<Integer>(relax::attr::kWorkspaceSize)) {
+        max_workspace_size_ = std::max<size_t>(max_workspace_size_, 
workspace.value()->value);
+      }
+    }
+
+    if (max_workspace_size_ == 0) {
+      return mod_;
+    }
+
+    auto new_funcs = relax::ExternFunctionRewriter(mod_, 
max_workspace_size_).Run();
+
+    for (const auto& [gvar, f] : new_funcs) {
+      auto new_gvar = builder_->AddFunction(f, gvar->name_hint);
+      // This is only required since the well-formed check requires 
kGlobalSymbol to be the same
+      // as the actual name of the global variable.
+      builder_->UpdateFunction(new_gvar,
+                               WithAttr(f, tvm::attr::kGlobalSymbol, 
new_gvar->name_hint));
+      gvar_map_[gvar] = new_gvar;
+      builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
+    }
+
+    auto gvar = mod_->GetGlobalVar("main");
+    auto func = Downcast<Function>(mod_->Lookup(gvar));
+    auto new_func =
+        Function(func->params, VisitExpr(func->body), func->ret_struct_info, 
func->attrs);
+    builder_->UpdateFunction(gvar, new_func);
+    return builder_->GetContextIRModule();
+  }
+
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
+    builder_->BeginDataflowBlock();
+    if (!workspace_var_main_.defined()) {
+      auto shape = ShapeExpr({Integer(max_workspace_size_)});
+      auto ty = DataTypeImm(DataType::UInt(8));
+      auto storage = MakeVMAllocStorage(shape, PrimValue::Int64(0), ty);
+      auto workspace = MakeVMAllocTensor(storage, PrimValue::Int64(0), shape, 
ty);
+      workspace_var_main_ = builder_->Emit(workspace, "workspace_main");
+    }
+    for (const auto& binding : block_node->bindings) {
+      this->VisitBinding(binding);
+    }
+    return builder_->EndBlock();
+  }
+
+  Expr VisitExpr_(const GlobalVarNode* gvar_node) override {
+    if (gvar_map_.count(gvar_node)) {
+      return gvar_map_[gvar_node];
+    }
+    return ExprMutator::VisitExpr_(gvar_node);
+  }
+
+  Expr VisitExpr_(const CallNode* call_node) override {
+    auto new_op = VisitExpr(call_node->op);
+
+    if (auto gv = new_op.as<GlobalVar>()) {
+      auto callee = builder_->GetContextIRModule()->Lookup(gv.value());
+      if (callee->HasNonzeroAttr(attr::kWorkspaceSize)) {
+        auto new_args = call_node->args;
+        ICHECK(workspace_var_main_.defined());
+        new_args.push_back(workspace_var_main_);
+        return Call(new_op, new_args, call_node->attrs, call_node->sinfo_args, 
call_node->span);
+      }
+    }
+
+    return ExprMutator::VisitExpr_(call_node);
+  }
+
+ private:
+  IRModule mod_;
+  /*! \brief A variable that represents the workspace created at the beginning 
of main. */
+  Var workspace_var_main_;
+  size_t max_workspace_size_ = 0;
+  /*! \brief A map from old global variables representing a function with 
workspace requirement to
+   * the new ones that are transformed to take an additional workspace 
parameter. This is only
+   * needed since the struct info of the global variables changes between 
transformation. */
+  std::unordered_map<const GlobalVarNode*, GlobalVar> gvar_map_;
+};
+
+}  // namespace relax
+
+namespace transform {
+
+Pass AllocateWorkspace() {
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+      [=](IRModule m, PassContext pc) { return 
relax::WorkspaceProvider(m).Run(); };
+
+  return CreateModulePass(pass_func, 0, "AllocateWorkspace", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.AllocateWorkspace").set_body_typed(AllocateWorkspace);
+
+}  // namespace transform
+}  // namespace tvm
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 45d66b3704..7a831c094e 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -83,9 +83,9 @@ cutlass_enabled = pytest.mark.skipif(
 pytestmark = [cutlass_enabled]
 
 
-def build_and_run(mod, inputs_np, target, legalize=False):
+def build_and_run(mod, inputs_np, target, legalize=True):
     if legalize:
-        mod = relax.transform.LegalizeOps()(mod)
+        mod = relax.transform.LegalizeOps()(mod)  # For cpu reference, nop for 
cutlass.
 
     dev = tvm.device(target, 0)
     ex = relax.build(mod, target)
@@ -95,11 +95,13 @@ def build_and_run(mod, inputs_np, target, legalize=False):
     return f(*inputs).numpy()
 
 
-def get_result_with_relax_cutlass_offload(mod, *args, 
assert_all_bindings_fused=True):
+def get_result_with_relax_cutlass_offload(
+    mod, *args, assert_all_bindings_fused=True, num_final_bindings=1
+):
     mod = partition_for_cutlass(mod)
 
     if assert_all_bindings_fused:
-        assert len(mod["main"].body.blocks[0].bindings) == 1
+        assert len(mod["main"].body.blocks[0].bindings) == num_final_bindings
 
     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, 
"find_first_valid": True}})
     mod = codegen_pass(mod)
@@ -116,7 +118,7 @@ def test_kernel_sharing():
     out = get_result_with_relax_cutlass_offload(
         Conv2dx2, data_np, weight1_np, weight2_np, 
assert_all_bindings_fused=False
     )
-    ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm", 
legalize=True)
+    ref = build_and_run(Conv2dx2, [data_np, weight1_np, weight2_np], "llvm")
 
     np.testing.assert_equal(out, ref)
 
@@ -243,7 +245,7 @@ def test_conv2d_offload(data_shape, weight_shape, dtype, 
epilogue, residual_bloc
     )
     out = get_result_with_relax_cutlass_offload(mod, *args)
 
-    ref = build_and_run(mod, args, "llvm", legalize=True)
+    ref = build_and_run(mod, args, "llvm")
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)
 
@@ -369,7 +371,7 @@ def test_matmul_offload(
         residual_activation=residual_activation,
     )
     out = get_result_with_relax_cutlass_offload(mod, *args)
-    ref = build_and_run(mod, args, "llvm", legalize=True)
+    ref = build_and_run(mod, args, "llvm")
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
@@ -616,7 +618,7 @@ def test_attention_offload(attention_size, attention_dtype):
     )
 
     mod = get_relax_attention_module(q, k, v)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
@@ -645,7 +647,7 @@ def test_attention_bias_offload(attention_bias_size):
     )
 
     mod = get_relax_attention_module(q, k, v, bias)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=3)
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
@@ -674,9 +676,9 @@ def test_attention_scale_offload(attention_scale_size, 
attention_scale):
 
     mod = get_relax_attention_module(q, k, v, bias, attention_scale)
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, 
num_final_bindings=3)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -777,9 +779,9 @@ def 
test_stacked_attention_split_offload(stacked_attention_size):
         )
 
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, 
num_final_bindings=3)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, 
num_final_bindings=3)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -795,9 +797,9 @@ def 
test_stacked_attention_strided_slice_offload(stacked_attention_size):
             qkv, b, s, n, h, h_v, "strided_slice", bias, scale, 
single_shape=single_shape
         )
     if bias is None:
-        out = get_result_with_relax_cutlass_offload(mod, qkv)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, 
num_final_bindings=3)
     else:
-        out = get_result_with_relax_cutlass_offload(mod, qkv, bias)
+        out = get_result_with_relax_cutlass_offload(mod, qkv, bias, 
num_final_bindings=3)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
@@ -966,8 +968,8 @@ def test_attention_rewrite_offload(attention_rewrite_size):
         expected_out = build_and_run(expected_mod, [q, k, v], "cuda")
         tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, 
atol=1e-5)
     else:
-        original_out = build_and_run(original_mod, [q, k, v, bias], "cuda")
-        expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda")
+        original_out = build_and_run(original_mod, [q, k, v, bias], "cuda", 
legalize=False)
+        expected_out = build_and_run(expected_mod, [q, k, v, bias], "cuda", 
legalize=False)
         tvm.testing.assert_allclose(original_out, expected_out, rtol=1e-5, 
atol=1e-5)
 
 
@@ -1043,7 +1045,7 @@ def test_layer_norm(data_shape, dtype, axes):
     gamma = np.random.randn(data_shape[-1]).astype(dtype)
     beta = np.random.randn(data_shape[-1]).astype(dtype)
     out = build_and_run(mod, [inp, gamma, beta], "cuda")
-    ref = build_and_run(Module, [inp, gamma, beta], "llvm", legalize=True)
+    ref = build_and_run(Module, [inp, gamma, beta], "llvm")
 
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
diff --git a/tests/python/relax/test_transform_allocate_workspace.py b/tests/python/relax/test_transform_allocate_workspace.py
new file mode 100644
index 0000000000..7ffbd01b05
--- /dev/null
+++ b/tests/python/relax/test_transform_allocate_workspace.py
@@ -0,0 +1,132 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+
+
+@I.ir_module
+class Module:
+    @R.function
+    def fused_relax_nn_attention_cutlass(
+        q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+    ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+        R.func_attr(
+            {
+                "Codegen": "cutlass",
+                "WorkspaceSize": 65536,
+                "global_symbol": "fused_relax_nn_attention_cutlass",
+            }
+        )
+
+        @R.function
+        def gv(
+            q_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+            k_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+            v_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+            R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, 
"WorkspaceSize": 65536})
+            with R.dataflow():
+                gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") = 
R.nn.attention(
+                    q_1, k_1, v_1, scale=None
+                )
+                R.output(gv_2)
+            return gv_2
+
+        gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v)
+        return gv1
+
+    @R.function
+    def main(
+        q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+    ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+        cls = Module
+        with R.dataflow():
+            gv: R.Tensor((32, 8, 16, 8), dtype="float16") = 
cls.fused_relax_nn_attention_cutlass(
+                q, k, v
+            )
+            R.output(gv)
+        return gv
+
+
+@I.ir_module
+class Expected:
+    @R.function
+    def fused_relax_nn_attention_cutlass1(
+        q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        workspace: R.Tensor((65536,), dtype="uint8"),
+    ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+        R.func_attr(
+            {
+                "Codegen": "cutlass",
+                "WorkspaceSize": 65536,
+                "global_symbol": "fused_relax_nn_attention_cutlass1",
+            }
+        )
+
+        @R.function
+        def gv(
+            q_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+            k_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+            v_1: R.Tensor((32, 8, 16, 8), dtype="float16"),
+            workspace_1: R.Tensor((65536,), dtype="uint8"),
+        ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+            R.func_attr({"Composite": "cutlass.attention", "Primitive": 1, 
"WorkspaceSize": 65536})
+            with R.dataflow():
+                gv_2: R.Tensor((32, 8, 16, 8), dtype="float16") = 
R.nn.attention(
+                    q_1, k_1, v_1, scale=None
+                )
+                R.output(gv_2)
+            return gv_2
+
+        gv1: R.Tensor((32, 8, 16, 8), dtype="float16") = gv(q, k, v, workspace)
+        return gv1
+
+    @R.function
+    def main(
+        q: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        k: R.Tensor((32, 8, 16, 8), dtype="float16"),
+        v: R.Tensor((32, 8, 16, 8), dtype="float16"),
+    ) -> R.Tensor((32, 8, 16, 8), dtype="float16"):
+        cls = Expected
+        with R.dataflow():
+            lv: R.Object = R.vm.alloc_storage(R.shape([65536]), 
R.prim_value(0), R.dtype("uint8"))
+            workspace_main: R.Tensor((65536,), dtype="uint8") = 
R.vm.alloc_tensor(
+                lv, R.prim_value(0), R.shape([65536]), R.dtype("uint8")
+            )
+            gv: R.Tensor((32, 8, 16, 8), dtype="float16") = 
cls.fused_relax_nn_attention_cutlass1(
+                q, k, v, workspace_main
+            )
+            R.output(gv)
+        return gv
+
+
+def test_single_attention():
+    rewritten = relax.transform.AllocateWorkspace()(Module)
+    tvm.ir.assert_structural_equal(rewritten, Expected)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to