This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 23146d651a [Unity][Op] Expose scale in `R.nn.attention` and add its legalize op (#14412)
23146d651a is described below

commit 23146d651a61221541fac4bddda805003c77d2f5
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Mar 27 22:11:12 2023 -0700

    [Unity][Op] Expose scale in `R.nn.attention` and add its legalize op (#14412)
    
    This PR exposes the custom scale in `R.nn.attention` and adds its
    legalize op.
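    
    A minimal usage sketch of the new parameter (illustrative only: `q`, `k`
    and `v` are assumed to be Relax expressions with BSNH layout built
    elsewhere, and the re-exported import path is an assumption; the op itself
    matches the signature added in this diff):
    
        from tvm import tir
        from tvm.relax.op.nn import attention
    
        # Hypothetical example: apply a custom softmax scale of 0.5.
        scale = tir.FloatImm("float32", 0.5)
        out = attention(q, k, v, bias=None, scale=scale)
        # With scale=None, legalization falls back to 1 / sqrt(head_dim).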
---
 include/tvm/relax/attrs/nn.h                       |  10 ++
 python/tvm/contrib/cutlass/attention_operation.py  |  14 +-
 python/tvm/contrib/cutlass/build.py                |   9 +-
 python/tvm/contrib/cutlass/gen_tensor_op.py        |  11 +-
 python/tvm/contrib/cutlass/library.py              |   4 +-
 python/tvm/relax/op/nn/nn.py                       |  14 +-
 python/tvm/relax/transform/legalize_ops/nn.py      |  56 +++++++
 src/relax/op/nn/attention.cc                       |  14 +-
 src/relax/op/nn/attention.h                        |   2 +-
 tests/python/relax/test_codegen_cutlass.py         | 125 +++++++---------
 .../python/relax/test_transform_legalize_ops_nn.py | 163 +++++++++++++++++++++
 11 files changed, 337 insertions(+), 85 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 3daa32fd76..bcfe3207bc 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -295,6 +295,16 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   }
 };  // struct DropoutAttrs
 
+/*! \brief Attributes used in attention operator */
+struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+  Optional<FloatImm> scale;
+
+  TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
+    TVM_ATTR_FIELD(scale).describe(
+        "The custom scale applied before the softmax. The default value is 1 / 
sqrt(head_dim).");
+  }
+};  // struct AttentionAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 9093a03dd6..f7dee4e3b8 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -90,7 +90,7 @@ def instantiate_attention_template(attrs, func_args):
   p.head_dim_value = ${head_dim_value}; // H'
   p.num_queries = ${num_queries}; // S
   p.num_keys = ${num_keys}; // S'
-  p.scale = 1.0f / sqrt(float(${head_dim}));
+  p.scale = ${scale};
 
   // stride for N
   p.q_strideH = p.head_dim; // H
@@ -123,12 +123,12 @@ def instantiate_attention_template(attrs, func_args):
   CHECK(Attention::check_supported(p));
   kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
 """
-    if attrs["kSupportsBias"]:
-        template = substitute_template(
-            template, {"bias_template": bias_template[attrs["bias_layout"]]}
-        )
-    else:
-        template = substitute_template(template, {"bias_template": ""})
+
+    template = substitute_template(
+        template,
+        {"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout" 
in attrs else ""},
+    )
+
     for i, arg in enumerate(func_args):
         attrs["arg{}".format(i)] = arg
     return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index d9eefd34a3..93d1331ac4 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -786,7 +786,12 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
     def handle_attention(self, f, op_type):
         """Tune and annotate a dense op."""
         signature = _extract_relax_function_signature(f)
-
+        if _get_call_node(f.body, "relax.nn.attention") is not None:
+            op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
+        elif _get_call_node(f.body, "relax.nn.attention_bias") is not None:
+            op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
+        else:
+            raise ValueError("Cannot find call node for attention")
         q_shape = signature["arg0_shape"]
         k_shape = signature["arg1_shape"]
         v_shape = signature["arg2_shape"]
@@ -798,6 +803,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
         num_batches, num_queries, num_heads, head_dim = q_shape
         _, num_keys, _, _ = k_shape
         _, _, _, head_dim_value = v_shape
+        scale = op_attrs.scale
         bias = {}
         if "arg3_dtype" in signature:
             bias["arg3_dtype"] = signature["arg3_dtype"]
@@ -821,6 +827,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
                 "num_heads": num_heads,
                 "head_dim": head_dim,
                 "head_dim_value": head_dim_value,
+                "scale": scale,
                 "arch": self.options["sm"],
                 **bias,
             }
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 1de1938580..61c88c657f 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name
 """Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
 import logging
+import math
 import multiprocessing
 import os
 import re
@@ -722,6 +723,12 @@ def instantiate_template(func_name, annotations, func_args):
             attrs["kKeysPerBlock"] = 64
             attrs["kSingleValueIteration"] = True
         attrs["output_size"] = b * s * n * h_v
+        attrs["scale"] = (
+            float(1 / math.sqrt(h.value)) if annotations["scale"] is None else 
annotations["scale"]
+        )
+        assert attrs["scale"] != 0, "CUTLASS may occasionally generate NaN when scale == 0.0"
         attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
         attrs["kSupportsDropout"] = False
         if len(func_args) > 3:
@@ -735,7 +742,9 @@ def instantiate_template(func_name, annotations, func_args):
             else:
                 raise NotImplementedError()
         else:
-            attrs["kSupportsBias"] = False
+            # To support a negative scale in the current CUTLASS implementation,
+            # kSupportsBias must be set to true; otherwise the results contain NaNs.
+            attrs["kSupportsBias"] = attrs["scale"] < 0
         code = instantiate_attention_template(attrs, func_args)
         return CodegenResult(code, headers)
 
diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py
index b72553ef60..ead5804b59 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -20,7 +20,7 @@ import re
 import enum
 from enum import auto as enum_auto
 
-from tvm.tir.expr import IntImm
+from tvm.tir.expr import IntImm, FloatImm
 
 
 class GeneratorTarget(enum.Enum):
@@ -147,6 +147,8 @@ def substitute_template(template, values):
         for key, value in values.items():
             if isinstance(value, (int, IntImm)):
                 value = str(int(value))
+            elif isinstance(value, (float, FloatImm)):
+                value = str(float(value))
             elif isinstance(value, bool):
                 value = str(value).lower()
             regex = "\\$\\{%s\\}" % key
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index e1d41c6cdf..02468637e0 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -18,6 +18,7 @@
 from typing import List, Optional, Tuple, Union
 
 from tvm import DataType
+from tvm.tir import FloatImm
 
 from . import _ffi_api
 from ...expr import Expr
@@ -913,7 +914,13 @@ def cross_entropy_with_logits(predictions: Expr, labels: Expr) -> Expr:
     return _ffi_api.cross_entropy_with_logits(predictions, labels)  # type: ignore
 
 
-def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None) -> Expr:
+def attention(
+    query: Expr,
+    key: Expr,
+    value: Expr,
+    bias: Optional[Expr] = None,
+    scale: Optional[FloatImm] = None,
+) -> Expr:
     r"""Computes fused multi head attention.
 
     All input tensors are of 4-D tensors with BSNH layout.
@@ -943,10 +950,13 @@ def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] = None)
         (batch_size, num_head, seq_len, seq_len_kv),
         (batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv).
 
+    scale: Optional[FloatImm]
+        The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).
+
     Returns
     -------
     result : relax.Expr
         The computed result. The layout of the output should be
         (batch_size, seq_len, num_head, head_dim_v).
     """
-    return _ffi_api.attention(query, key, value, bias)  # type: ignore
+    return _ffi_api.attention(query, key, value, bias, scale)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 889e6e0941..1ce4520635 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -312,3 +312,59 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
     logging.info("Dropout is handled by frontend translator at this moment and 
is not legalized.")
     return call
+
+
+def _te_attention(
+    q: te.Tensor, k: te.Tensor, v: te.Tensor, bias: te.Tensor, scale: tir.FloatImm
+) -> te.Tensor:
+    batch_size, seq_len, num_head, head_dim = q.shape
+    _, seq_len_kv, _, head_dim_v = v.shape
+    q = topi.transpose(q, [0, 2, 1, 3])
+    k = topi.transpose(k, [0, 2, 1, 3])
+    v = topi.transpose(v, [0, 2, 1, 3])
+    q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
+    k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
+    v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v])
+    p = topi.nn.batch_matmul(q, k)
+    if scale is not None:
+        p = topi.multiply(p, scale)
+    else:
+        p = topi.divide(p, tir.sqrt(tir.Cast(p.dtype, head_dim)))
+    if bias is not None:
+        p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv])
+        if len(bias.shape) == 2:
+            bias = topi.reshape(bias, [batch_size, 1, 1, seq_len_kv])
+        elif len(bias.shape) == 3:
+            bias = topi.reshape(bias, [batch_size, 1, seq_len, seq_len_kv])
+        p = topi.add(p, bias)
+        p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv])
+    s = topi.nn.softmax(p)
+    o = topi.nn.batch_matmul(s, v, transpose_b=False)
+    o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v])
+    return topi.transpose(o, [0, 2, 1, 3])
+
+
+@register_legalize("relax.nn.attention")
+def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        _te_attention,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        None,
+        call.attrs.scale,
+        primfunc_name_hint="attention",
+    )
+
+
+@register_legalize("relax.nn.attention_bias")
+def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        _te_attention,
+        call.args[0],
+        call.args[1],
+        call.args[2],
+        call.args[3],
+        call.attrs.scale,
+        primfunc_name_hint="attention_bias",
+    )
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index e139aa09d6..c27e8b68d0 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -26,14 +26,18 @@ namespace tvm {
 namespace relax {
 
 /* relax.nn.attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias) {
+TVM_REGISTER_NODE_TYPE(AttentionAttrs);
+
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale) {
+  ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
+  attrs->scale = scale;
   if (bias.defined()) {
     return Call(Op::Get("relax.nn.attention_bias"),
-                {std::move(query), std::move(key), std::move(value), std::move(bias.value())}, {},
-                {});
+                {std::move(query), std::move(key), std::move(value), std::move(bias.value())},
+                Attrs(attrs), {});
   }
   return Call(Op::Get("relax.nn.attention"), {std::move(query), 
std::move(key), std::move(value)},
-              {}, {});
+              Attrs(attrs), {});
 }
 
 TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention);
@@ -105,6 +109,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
 }
 
 TVM_REGISTER_OP("relax.nn.attention")
+    .set_attrs_type<AttentionAttrs>()
     .set_num_inputs(3)
     .add_argument("query", "Tensor", "The input queries tensor.")
     .add_argument("key", "Tensor", "The input keys tensor.")
@@ -112,6 +117,7 @@ TVM_REGISTER_OP("relax.nn.attention")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);
 
 TVM_REGISTER_OP("relax.nn.attention_bias")
+    .set_attrs_type<AttentionAttrs>()
     .set_num_inputs(4)
     .add_argument("query", "Tensor", "The input queries tensor.")
     .add_argument("key", "Tensor", "The input keys tensor.")
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
index 662e0b7e7b..7eda30b408 100644
--- a/src/relax/op/nn/attention.h
+++ b/src/relax/op/nn/attention.h
@@ -33,7 +33,7 @@ namespace tvm {
 namespace relax {
 
 /*! \brief fused multi head attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias);
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<FloatImm> scale);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 5ea9a9d040..c8ca44311d 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -566,11 +566,14 @@ def attention_size(request):
     return request.param
 
 
-def get_relax_attention_module(q, k, v, bias=None):
+def get_relax_attention_module(q, k, v, bias=None, qk_scale=None):
     dtype = str(q.dtype)
 
     from tvm.script.ir_builder import IRBuilder
-    from tvm.script.ir_builder import relax as relax_builder
+    from tvm.script.ir_builder import relax as relax_builder, tir as T
+
+    if qk_scale is not None:
+        qk_scale = T.FloatImm("float32", qk_scale)
 
     with IRBuilder() as builder:
         with relax_builder.function():
@@ -581,7 +584,7 @@ def get_relax_attention_module(q, k, v, bias=None):
             if bias is not None:
                 bias = R.arg("bias", R.Tensor(bias.shape, dtype))
             with R.dataflow() as frame:
-                result = R.emit(R.nn.attention(q, k, v, bias))
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
                 R.output(result)
 
             R.func_ret_value(frame.output_vars[0])
@@ -591,22 +594,32 @@ def get_relax_attention_module(q, k, v, bias=None):
 
 
 @memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
-def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, dtype):
+def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, qk_scale, dtype):
     q = np.random.randn(b, s, n, h).astype(dtype)
     k = np.random.randn(b, s_kv, n, h).astype(dtype)
     v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
     qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
     kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    if not qk_scale == "none":
+        score = qt @ kt * qk_scale  # b, n, s, s_kv
+    else:
+        score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
+    if not bias_shape == "none":
+        bias = np.random.randn(*bias_shape).astype(dtype)
+        score = score + bias.reshape(*bias_reshape)  # b, n, s, s_kv
+    else:
+        bias = None
     attn = tvm.topi.testing.softmax_python(score, -1)
     vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
     ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
 
 
 def test_attention_offload(attention_size, attention_dtype):
     b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, ref = get_numpy_attention_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+    q, k, v, _, ref = get_numpy_attention_ref(
+        b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype
+    )
 
     mod = get_relax_attention_module(q, k, v)
     out = get_result_with_relax_cutlass_offload(mod, q, k, v)
@@ -614,25 +627,23 @@ def test_attention_offload(attention_size, attention_dtype):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_4d_offload")
-def get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, dtype):
-    q = np.random.randn(b, s, n, h).astype(dtype)
-    k = np.random.randn(b, s_kv, n, h).astype(dtype)
-    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
-    bias = np.random.randn(b, n, s, s_kv).astype(dtype)
-    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
-    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
-    score_bias = score + bias  # b, n, s, s_kv
-    attn = tvm.topi.testing.softmax_python(score_bias, -1)
-    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
-    ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+@pytest.fixture(
+    params=[
+        # B, S, N, H, bias_shape, bias_reshape
+        (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), (4, 32, 16, 8)),
+        (4, (16, 8), 32, (8, 16), (4, 16, 8), (4, 1, 16, 8)),
+        (4, (16, 8), 32, (8, 16), (4, 8), (4, 1, 1, 8)),
+    ]
+)
+def attention_bias_size(request):
+    return request.param
 
 
-def test_attention_bias_4d_offload(attention_size, attention_dtype):
-    b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, bias, ref = get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
+def test_attention_bias_offload(attention_bias_size, attention_dtype):
+    b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_bias_size
+    q, k, v, bias, ref = get_numpy_attention_ref(
+        b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none", 
attention_dtype
+    )
 
     mod = get_relax_attention_module(q, k, v, bias)
     out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
@@ -640,55 +651,33 @@ def test_attention_bias_4d_offload(attention_size, attention_dtype):
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_3d_offload")
-def get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, dtype):
-    q = np.random.randn(b, s, n, h).astype(dtype)
-    k = np.random.randn(b, s_kv, n, h).astype(dtype)
-    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
-    bias = np.random.randn(b, s, s_kv).astype(dtype)
-    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
-    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
-    score_bias = score + bias.reshape(b, 1, s, s_kv)  # b, n, s, s_kv
-    attn = tvm.topi.testing.softmax_python(score_bias, -1)
-    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
-    ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
-
-
-def test_attention_bias_3d_offload(attention_size, attention_dtype):
-    b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, bias, ref = get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
-
-    mod = get_relax_attention_module(q, k, v, bias)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
-
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-
+@pytest.fixture(
+    params=[
+        # B, S, N, H, bias_shape, bias_reshape
+        (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), (4, 32, 16, 8)),
+        (4, (16, 8), 32, (8, 16), "none", "none"),
+    ]
+)
+def attention_scale_size(request):
+    return request.param
 
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_2d_offload")
-def get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, dtype):
-    q = np.random.randn(b, s, n, h).astype(dtype)
-    k = np.random.randn(b, s_kv, n, h).astype(dtype)
-    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
-    bias = np.random.randn(b, s_kv).astype(dtype)
-    qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
-    kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
-    score = qt @ kt / np.sqrt(q.shape[-1])  # b, n, s, s_kv
-    score_bias = score + bias.reshape(b, 1, 1, s_kv)  # b, n, s, s_kv
-    attn = tvm.topi.testing.softmax_python(score_bias, -1)
-    vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
-    ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
 
+@pytest.fixture(params=[0.01, 1e-8, -0.5, 1.23])
+def attention_scale(request):
+    return request.param
 
-def test_attention_bias_2d_offload(attention_size, attention_dtype):
-    b, (s, s_kv), n, (h, h_v) = attention_size
-    q, k, v, bias, ref = get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, attention_dtype)
 
-    mod = get_relax_attention_module(q, k, v, bias)
-    out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+def test_attention_scale_offload(attention_scale_size, attention_scale, attention_dtype):
+    b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_scale_size
+    q, k, v, bias, ref = get_numpy_attention_ref(
+        b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale, attention_dtype
+    )
 
+    mod = get_relax_attention_module(q, k, v, bias, attention_scale)
+    if bias is None:
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+    else:
+        out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index e944b8d76e..e807082e35 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2280,5 +2280,168 @@ def test_group_norm_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_attention():
+    # fmt: off
+    @tvm.script.ir_module
+    class Attention:
+        @R.function
+        def main(q: R.Tensor((4, 16, 32, 8), "float32"), k: R.Tensor((4, 8, 
32, 8), "float32"), v: R.Tensor((4, 8, 32, 16), "float32"), bias: R.Tensor((4, 
32, 16, 8), "float32")):
+            scale = T.FloatImm("float32", 0.1)
+            gv: R.Tensor((4, 16, 32, 16), "float32") = R.nn.attention(q, k, v, 
bias, scale)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def attention_bias(rxplaceholder: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(8)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(8)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"), rxplaceholder_3: T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)), "float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32), T.int64(16)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            T_transpose_1 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
+            T_reshape = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_transpose_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(8)))
+            T_reshape_1 = T.alloc_buffer((T.int64(128), T.int64(8), T.int64(8)))
+            T_batch_matmul_NT = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_multiply = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
+            T_add = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)))
+            T_reshape_3 = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_softmax_maxelem = T.alloc_buffer((T.int64(128), T.int64(16)))
+            T_softmax_exp = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_softmax_expsum = T.alloc_buffer((T.int64(128), T.int64(16)))
+            T_softmax_norm = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+            T_transpose_3 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(8), T.int64(16)))
+            T_reshape_4 = T.alloc_buffer((T.int64(128), T.int64(8), T.int64(16)))
+            T_batch_matmul_NN = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(16)))
+            T_reshape_5 = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(16)))
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
+                with T.block("T_transpose"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder[v_ax0, v_ax2, v_ax1, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)])
+                    T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+                    T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(8)):
+                with T.block("T_transpose_1"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder_1[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_1[v_ax0, v_ax2, v_ax1, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(8)):
+                with T.block("T_reshape_1"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)])
+                    T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
+                    T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
+            for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(8), T.int64(8)):
+                with T.block("T_batch_matmul_NT"):
+                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
+                    T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j, v_k])
+                    T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
+                    T.block_attr({"layout_free_placeholders": [T_reshape_1]})
+                    with T.init():
+                        T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0)
+                    T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b, v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+                    T_multiply[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0, v_ax1, v_ax2] * T.float32(0.10000000000000001)
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
+                with T.block("T_reshape_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)])
+                    T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] = T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(8)):
+                with T.block("T_add"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3], rxplaceholder_3[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] + rxplaceholder_3[v_ax0, v_ax1, v_ax2, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_reshape_3"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)])
+                    T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2])
+                    T_reshape_3[v_ax0, v_ax1, v_ax2] = T_add[((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]
+            for i0, i1, k in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_softmax_maxelem"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(T_reshape_3[v_i0, v_i1, v_k])
+                    T.writes(T_softmax_maxelem[v_i0, v_i1])
+                    with T.init():
+                        T_softmax_maxelem[v_i0, v_i1] = T.float32(-3.4028234663852886e+38)
+                    T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, v_i1], T_reshape_3[v_i0, v_i1, v_k])
+            for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_softmax_exp"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(T_reshape_3[v_i0, v_i1, v_i2], T_softmax_maxelem[v_i0, v_i1])
+                    T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
+                    T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(T_reshape_3[v_i0, v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
+            for i0, i1, k in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_softmax_expsum"):
+                    v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_k])
+                    T.writes(T_softmax_expsum[v_i0, v_i1])
+                    with T.init():
+                        T_softmax_expsum[v_i0, v_i1] = T.float32(0)
+                    T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
+            for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+                with T.block("T_softmax_norm"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(T_softmax_exp[v_i0, v_i1, v_i2], T_softmax_expsum[v_i0, v_i1])
+                    T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
+                    T.block_attr({"axis": 2})
+                    T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(8), T.int64(16)):
+                with T.block("T_transpose_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(rxplaceholder_2[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] = rxplaceholder_2[v_ax0, v_ax2, v_ax1, v_ax3]
+            for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(16)):
+                with T.block("T_reshape_4"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)])
+                    T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2])
+                    T_reshape_4[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)]
+            for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(16), T.int64(8)):
+                with T.block("T_batch_matmul_NN"):
+                    v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
+                    T.reads(T_softmax_norm[v_b, v_i, v_k], T_reshape_4[v_b, v_k, v_j])
+                    T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
+                    T.block_attr({"layout_free_placeholders": [T_reshape_4]})
+                    with T.init():
+                        T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0)
+                    T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b, v_i, v_j] + T_softmax_norm[v_b, v_i, v_k] * T_reshape_4[v_b, v_k, v_j]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32), T.int64(16), T.int64(16)):
+                with T.block("T_reshape_5"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)])
+                    T.writes(T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3] = T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)]
+            for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(16), T.int64(32), T.int64(16)):
+                with T.block("T_transpose_3"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3])
+                    T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3]
+
+        @R.function
+        def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4, 
8, 32, 8), dtype="float32"), v: R.Tensor((4, 8, 32, 16), dtype="float32"), 
bias: R.Tensor((4, 32, 16, 8), dtype="float32")) -> R.Tensor((4, 16, 32, 16), 
dtype="float32"):
+            gv = R.call_tir(Expected.attention_bias, (q, k, v, bias), 
out_sinfo=R.Tensor((4, 16, 32, 16), dtype="float32"))
+            return gv
+    # fmt: on
+
+    mod = LegalizeOps()(Attention)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

