This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 23146d651a [Unity][Op] Expose scale in `R.nn.attention` and add its
legalize op (#14412)
23146d651a is described below
commit 23146d651a61221541fac4bddda805003c77d2f5
Author: Yaxing Cai <[email protected]>
AuthorDate: Mon Mar 27 22:11:12 2023 -0700
[Unity][Op] Expose scale in `R.nn.attention` and add its legalize op
(#14412)
This PR exposes the custom scale in `R.nn.attention` and adds its
legalize op.
---
include/tvm/relax/attrs/nn.h | 10 ++
python/tvm/contrib/cutlass/attention_operation.py | 14 +-
python/tvm/contrib/cutlass/build.py | 9 +-
python/tvm/contrib/cutlass/gen_tensor_op.py | 11 +-
python/tvm/contrib/cutlass/library.py | 4 +-
python/tvm/relax/op/nn/nn.py | 14 +-
python/tvm/relax/transform/legalize_ops/nn.py | 56 +++++++
src/relax/op/nn/attention.cc | 14 +-
src/relax/op/nn/attention.h | 2 +-
tests/python/relax/test_codegen_cutlass.py | 125 +++++++---------
.../python/relax/test_transform_legalize_ops_nn.py | 163 +++++++++++++++++++++
11 files changed, 337 insertions(+), 85 deletions(-)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 3daa32fd76..bcfe3207bc 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -295,6 +295,16 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
}
}; // struct DropoutAttrs
+/*! \brief Attributes used in attention operator */
+struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
+ Optional<FloatImm> scale;
+
+ TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
+ TVM_ATTR_FIELD(scale).describe(
+ "The custom scale applied before the softmax. The default value is 1 /
sqrt(head_dim).");
+ }
+}; // struct AttentionAttrs
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/contrib/cutlass/attention_operation.py
b/python/tvm/contrib/cutlass/attention_operation.py
index 9093a03dd6..f7dee4e3b8 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -90,7 +90,7 @@ def instantiate_attention_template(attrs, func_args):
p.head_dim_value = ${head_dim_value}; // H'
p.num_queries = ${num_queries}; // S
p.num_keys = ${num_keys}; // S'
- p.scale = 1.0f / sqrt(float(${head_dim}));
+ p.scale = ${scale};
// stride for N
p.q_strideH = p.head_dim; // H
@@ -123,12 +123,12 @@ def instantiate_attention_template(attrs, func_args):
CHECK(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
"""
- if attrs["kSupportsBias"]:
- template = substitute_template(
- template, {"bias_template": bias_template[attrs["bias_layout"]]}
- )
- else:
- template = substitute_template(template, {"bias_template": ""})
+
+ template = substitute_template(
+ template,
+ {"bias_template": bias_template[attrs["bias_layout"]] if "bias_layout"
in attrs else ""},
+ )
+
for i, arg in enumerate(func_args):
attrs["arg{}".format(i)] = arg
return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/build.py
b/python/tvm/contrib/cutlass/build.py
index d9eefd34a3..93d1331ac4 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -786,7 +786,12 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
def handle_attention(self, f, op_type):
"""Tune and annotate a dense op."""
signature = _extract_relax_function_signature(f)
-
+ if _get_call_node(f.body, "relax.nn.attention") is not None:
+ op_attrs = _get_call_node(f.body, "relax.nn.attention").attrs
+ elif _get_call_node(f.body, "relax.nn.attention_bias") is not None:
+ op_attrs = _get_call_node(f.body, "relax.nn.attention_bias").attrs
+ else:
+ raise ValueError(f"Cannot find call node for attention")
q_shape = signature["arg0_shape"]
k_shape = signature["arg1_shape"]
v_shape = signature["arg2_shape"]
@@ -798,6 +803,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
num_batches, num_queries, num_heads, head_dim = q_shape
_, num_keys, _, _ = k_shape
_, _, _, head_dim_value = v_shape
+ scale = op_attrs.scale
bias = {}
if "arg3_dtype" in signature:
bias["arg3_dtype"] = signature["arg3_dtype"]
@@ -821,6 +827,7 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
"num_heads": num_heads,
"head_dim": head_dim,
"head_dim_value": head_dim_value,
+ "scale": scale,
"arch": self.options["sm"],
**bias,
}
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 1de1938580..61c88c657f 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -17,6 +17,7 @@
# pylint: disable=invalid-name
"""Common functions and classes for CUTLASS GEMM and Conv2d geneator."""
import logging
+import math
import multiprocessing
import os
import re
@@ -722,6 +723,12 @@ def instantiate_template(func_name, annotations,
func_args):
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True
attrs["output_size"] = b * s * n * h_v
+ attrs["scale"] = (
+ float(1 / math.sqrt(h.value)) if annotations["scale"] is None else
annotations["scale"]
+ )
+ assert (
+ attrs["scale"] > 0 or attrs["scale"] < 0
+ ), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
if len(func_args) > 3:
@@ -735,7 +742,9 @@ def instantiate_template(func_name, annotations, func_args):
else:
raise NotImplementedError()
else:
- attrs["kSupportsBias"] = False
+ # To support a negative scale in the current Cutlass implementation,
+ # kSupportsBias must be set to true; otherwise the result contains NaNs.
+ attrs["kSupportsBias"] = attrs["scale"] < 0
code = instantiate_attention_template(attrs, func_args)
return CodegenResult(code, headers)
diff --git a/python/tvm/contrib/cutlass/library.py
b/python/tvm/contrib/cutlass/library.py
index b72553ef60..ead5804b59 100644
--- a/python/tvm/contrib/cutlass/library.py
+++ b/python/tvm/contrib/cutlass/library.py
@@ -20,7 +20,7 @@ import re
import enum
from enum import auto as enum_auto
-from tvm.tir.expr import IntImm
+from tvm.tir.expr import IntImm, FloatImm
class GeneratorTarget(enum.Enum):
@@ -147,6 +147,8 @@ def substitute_template(template, values):
for key, value in values.items():
if isinstance(value, (int, IntImm)):
value = str(int(value))
+ if isinstance(value, (float, FloatImm)):
+ value = str(float(value))
elif isinstance(value, bool):
value = str(value).lower()
regex = "\\$\\{%s\\}" % key
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index e1d41c6cdf..02468637e0 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -18,6 +18,7 @@
from typing import List, Optional, Tuple, Union
from tvm import DataType
+from tvm.tir import FloatImm
from . import _ffi_api
from ...expr import Expr
@@ -913,7 +914,13 @@ def cross_entropy_with_logits(predictions: Expr, labels:
Expr) -> Expr:
return _ffi_api.cross_entropy_with_logits(predictions, labels) # type:
ignore
-def attention(query: Expr, key: Expr, value: Expr, bias: Optional[Expr] =
None) -> Expr:
+def attention(
+ query: Expr,
+ key: Expr,
+ value: Expr,
+ bias: Optional[Expr] = None,
+ scale: Optional[FloatImm] = None,
+) -> Expr:
r"""Computes fused multi head attention.
All input tensors are of 4-D tensors with BSNH layout.
@@ -943,10 +950,13 @@ def attention(query: Expr, key: Expr, value: Expr, bias:
Optional[Expr] = None)
(batch_size, num_head, seq_len, seq_len_kv),
(batch_size, seq_len, seq_len_kv) or (batch_size, seq_len_kv).
+ scale: Optional[FloatImm]
+ The custom scale applied before the softmax. The default value is 1 /
sqrt(head_dim).
+
Returns
-------
result : relax.Expr
The computed result. The layout of the output should be
(batch_size, seq_len, num_head, head_dim_v).
"""
- return _ffi_api.attention(query, key, value, bias) # type: ignore
+ return _ffi_api.attention(query, key, value, bias, scale) # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index 889e6e0941..1ce4520635 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -312,3 +312,59 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
logging.info("Dropout is handled by frontend translator at this moment and
is not legalized.")
return call
+
+
+def _te_attention(
+ q: te.Tensor, k: te.Tensor, v: te.Tensor, bias: te.Tensor, scale:
tir.FloatImm
+) -> te.Tensor:
+ batch_size, seq_len, num_head, head_dim = q.shape
+ _, seq_len_kv, _, head_dim_v = v.shape
+ q = topi.transpose(q, [0, 2, 1, 3])
+ k = topi.transpose(k, [0, 2, 1, 3])
+ v = topi.transpose(v, [0, 2, 1, 3])
+ q = topi.reshape(q, [batch_size * num_head, seq_len, head_dim])
+ k = topi.reshape(k, [batch_size * num_head, seq_len_kv, head_dim])
+ v = topi.reshape(v, [batch_size * num_head, seq_len_kv, head_dim_v])
+ p = topi.nn.batch_matmul(q, k)
+ if scale is not None:
+ p = topi.multiply(p, scale)
+ else:
+ p = topi.divide(p, tir.sqrt(tir.Cast(p.dtype, head_dim)))
+ if bias is not None:
+ p = topi.reshape(p, [batch_size, num_head, seq_len, seq_len_kv])
+ if len(bias.shape) == 2:
+ bias = topi.reshape(bias, [batch_size, 1, 1, seq_len_kv])
+ elif len(bias.shape) == 3:
+ bias = topi.reshape(bias, [batch_size, 1, seq_len, seq_len_kv])
+ p = topi.add(p, bias)
+ p = topi.reshape(p, [batch_size * num_head, seq_len, seq_len_kv])
+ s = topi.nn.softmax(p)
+ o = topi.nn.batch_matmul(s, v, transpose_b=False)
+ o = topi.reshape(o, [batch_size, num_head, seq_len, head_dim_v])
+ return topi.transpose(o, [0, 2, 1, 3])
+
+
+@register_legalize("relax.nn.attention")
+def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ _te_attention,
+ call.args[0],
+ call.args[1],
+ call.args[2],
+ None,
+ call.attrs.scale,
+ primfunc_name_hint="attention",
+ )
+
+
+@register_legalize("relax.nn.attention_bias")
+def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ _te_attention,
+ call.args[0],
+ call.args[1],
+ call.args[2],
+ call.args[3],
+ call.attrs.scale,
+ primfunc_name_hint="attention_bias",
+ )
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index e139aa09d6..c27e8b68d0 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -26,14 +26,18 @@ namespace tvm {
namespace relax {
/* relax.nn.attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias) {
+TVM_REGISTER_NODE_TYPE(AttentionAttrs);
+
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias,
Optional<FloatImm> scale) {
+ ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
+ attrs->scale = scale;
if (bias.defined()) {
return Call(Op::Get("relax.nn.attention_bias"),
- {std::move(query), std::move(key), std::move(value),
std::move(bias.value())}, {},
- {});
+ {std::move(query), std::move(key), std::move(value),
std::move(bias.value())},
+ Attrs(attrs), {});
}
return Call(Op::Get("relax.nn.attention"), {std::move(query),
std::move(key), std::move(value)},
- {}, {});
+ Attrs(attrs), {});
}
TVM_REGISTER_GLOBAL("relax.op.nn.attention").set_body_typed(attention);
@@ -105,6 +109,7 @@ StructInfo InferStructInfoAttention(const Call& call, const
BlockBuilder& ctx) {
}
TVM_REGISTER_OP("relax.nn.attention")
+ .set_attrs_type<AttentionAttrs>()
.set_num_inputs(3)
.add_argument("query", "Tensor", "The input queries tensor.")
.add_argument("key", "Tensor", "The input keys tensor.")
@@ -112,6 +117,7 @@ TVM_REGISTER_OP("relax.nn.attention")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention);
TVM_REGISTER_OP("relax.nn.attention_bias")
+ .set_attrs_type<AttentionAttrs>()
.set_num_inputs(4)
.add_argument("query", "Tensor", "The input queries tensor.")
.add_argument("key", "Tensor", "The input keys tensor.")
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
index 662e0b7e7b..7eda30b408 100644
--- a/src/relax/op/nn/attention.h
+++ b/src/relax/op/nn/attention.h
@@ -33,7 +33,7 @@ namespace tvm {
namespace relax {
/*! \brief fused multi head attention */
-Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias);
+Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias,
Optional<FloatImm> scale);
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 5ea9a9d040..c8ca44311d 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -566,11 +566,14 @@ def attention_size(request):
return request.param
-def get_relax_attention_module(q, k, v, bias=None):
+def get_relax_attention_module(q, k, v, bias=None, qk_scale=None):
dtype = str(q.dtype)
from tvm.script.ir_builder import IRBuilder
- from tvm.script.ir_builder import relax as relax_builder
+ from tvm.script.ir_builder import relax as relax_builder, tir as T
+
+ if qk_scale is not None:
+ qk_scale = T.FloatImm("float32", qk_scale)
with IRBuilder() as builder:
with relax_builder.function():
@@ -581,7 +584,7 @@ def get_relax_attention_module(q, k, v, bias=None):
if bias is not None:
bias = R.arg("bias", R.Tensor(bias.shape, dtype))
with R.dataflow() as frame:
- result = R.emit(R.nn.attention(q, k, v, bias))
+ result = R.emit(R.nn.attention(q, k, v, bias, qk_scale))
R.output(result)
R.func_ret_value(frame.output_vars[0])
@@ -591,22 +594,32 @@ def get_relax_attention_module(q, k, v, bias=None):
@memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
-def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, dtype):
+def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, bias_reshape,
qk_scale, dtype):
q = np.random.randn(b, s, n, h).astype(dtype)
k = np.random.randn(b, s_kv, n, h).astype(dtype)
v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
qt = q.transpose(0, 2, 1, 3) # b, n, s, h
kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
- score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+ if not qk_scale == "none":
+ score = qt @ kt * qk_scale # b, n, s, s_kv
+ else:
+ score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
+ if not bias_shape == "none":
+ bias = np.random.randn(*bias_shape).astype(dtype)
+ score = score + bias.reshape(*bias_reshape) # b, n, s, s_kv
+ else:
+ bias = None
attn = tvm.topi.testing.softmax_python(score, -1)
vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
ref = attn @ vt # b, n, s, h_v
- return q, k, v, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+ return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
def test_attention_offload(attention_size, attention_dtype):
b, (s, s_kv), n, (h, h_v) = attention_size
- q, k, v, ref = get_numpy_attention_ref(b, s, s_kv, n, h, h_v,
attention_dtype)
+ q, k, v, _, ref = get_numpy_attention_ref(
+ b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype
+ )
mod = get_relax_attention_module(q, k, v)
out = get_result_with_relax_cutlass_offload(mod, q, k, v)
@@ -614,25 +627,23 @@ def test_attention_offload(attention_size,
attention_dtype):
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_4d_offload")
-def get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h, h_v, dtype):
- q = np.random.randn(b, s, n, h).astype(dtype)
- k = np.random.randn(b, s_kv, n, h).astype(dtype)
- v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
- bias = np.random.randn(b, n, s, s_kv).astype(dtype)
- qt = q.transpose(0, 2, 1, 3) # b, n, s, h
- kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
- score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
- score_bias = score + bias # b, n, s, s_kv
- attn = tvm.topi.testing.softmax_python(score_bias, -1)
- vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
- ref = attn @ vt # b, n, s, h_v
- return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
[email protected](
+ params=[
+ # B, S, N, H, bias_shape, bias_reshape
+ (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), (4, 32, 16, 8)),
+ (4, (16, 8), 32, (8, 16), (4, 16, 8), (4, 1, 16, 8)),
+ (4, (16, 8), 32, (8, 16), (4, 8), (4, 1, 1, 8)),
+ ]
+)
+def attention_bias_size(request):
+ return request.param
-def test_attention_bias_4d_offload(attention_size, attention_dtype):
- b, (s, s_kv), n, (h, h_v) = attention_size
- q, k, v, bias, ref = get_numpy_attention_bias_4d_ref(b, s, s_kv, n, h,
h_v, attention_dtype)
+def test_attention_bias_offload(attention_bias_size, attention_dtype):
+ b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_bias_size
+ q, k, v, bias, ref = get_numpy_attention_ref(
+ b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, "none",
attention_dtype
+ )
mod = get_relax_attention_module(q, k, v, bias)
out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
@@ -640,55 +651,33 @@ def test_attention_bias_4d_offload(attention_size,
attention_dtype):
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_3d_offload")
-def get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h, h_v, dtype):
- q = np.random.randn(b, s, n, h).astype(dtype)
- k = np.random.randn(b, s_kv, n, h).astype(dtype)
- v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
- bias = np.random.randn(b, s, s_kv).astype(dtype)
- qt = q.transpose(0, 2, 1, 3) # b, n, s, h
- kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
- score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
- score_bias = score + bias.reshape(b, 1, s, s_kv) # b, n, s, s_kv
- attn = tvm.topi.testing.softmax_python(score_bias, -1)
- vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
- ref = attn @ vt # b, n, s, h_v
- return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
-
-
-def test_attention_bias_3d_offload(attention_size, attention_dtype):
- b, (s, s_kv), n, (h, h_v) = attention_size
- q, k, v, bias, ref = get_numpy_attention_bias_3d_ref(b, s, s_kv, n, h,
h_v, attention_dtype)
-
- mod = get_relax_attention_module(q, k, v, bias)
- out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
-
- tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-
[email protected](
+ params=[
+ # B, S, N, H, bias_shape, bias_reshape
+ (4, (16, 8), 32, (8, 16), (4, 32, 16, 8), (4, 32, 16, 8)),
+ (4, (16, 8), 32, (8, 16), "none", "none"),
+ ]
+)
+def attention_scale_size(request):
+ return request.param
-@memoize("topi.tests.test_codegen_cutlass.test_attention_bias_2d_offload")
-def get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h, h_v, dtype):
- q = np.random.randn(b, s, n, h).astype(dtype)
- k = np.random.randn(b, s_kv, n, h).astype(dtype)
- v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
- bias = np.random.randn(b, s_kv).astype(dtype)
- qt = q.transpose(0, 2, 1, 3) # b, n, s, h
- kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
- score = qt @ kt / np.sqrt(q.shape[-1]) # b, n, s, s_kv
- score_bias = score + bias.reshape(b, 1, 1, s_kv) # b, n, s, s_kv
- attn = tvm.topi.testing.softmax_python(score_bias, -1)
- vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
- ref = attn @ vt # b, n, s, h_v
- return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
[email protected](params=[0.01, 1e-8, -0.5, 1.23])
+def attention_scale(request):
+ return request.param
-def test_attention_bias_2d_offload(attention_size, attention_dtype):
- b, (s, s_kv), n, (h, h_v) = attention_size
- q, k, v, bias, ref = get_numpy_attention_bias_2d_ref(b, s, s_kv, n, h,
h_v, attention_dtype)
- mod = get_relax_attention_module(q, k, v, bias)
- out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
+def test_attention_scale_offload(attention_scale_size, attention_scale,
attention_dtype):
+ b, (s, s_kv), n, (h, h_v), bias_shape, bias_reshape = attention_scale_size
+ q, k, v, bias, ref = get_numpy_attention_ref(
+ b, s, s_kv, n, h, h_v, bias_shape, bias_reshape, attention_scale,
attention_dtype
+ )
+ mod = get_relax_attention_module(q, k, v, bias, attention_scale)
+ if bias is None:
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v)
+ else:
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias)
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index e944b8d76e..e807082e35 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2280,5 +2280,168 @@ def test_group_norm_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_attention():
+ # fmt: off
+ @tvm.script.ir_module
+ class Attention:
+ @R.function
+ def main(q: R.Tensor((4, 16, 32, 8), "float32"), k: R.Tensor((4, 8,
32, 8), "float32"), v: R.Tensor((4, 8, 32, 16), "float32"), bias: R.Tensor((4,
32, 16, 8), "float32")):
+ scale = T.FloatImm("float32", 0.1)
+ gv: R.Tensor((4, 16, 32, 16), "float32") = R.nn.attention(q, k, v,
bias, scale)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def attention_bias(rxplaceholder: T.Buffer((T.int64(4), T.int64(16),
T.int64(32), T.int64(8)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),
T.int64(8), T.int64(32), T.int64(8)), "float32"), rxplaceholder_2:
T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(16)), "float32"),
rxplaceholder_3: T.Buffer((T.int64(4), T.int64(32), T.int64(16), T.int64(8)),
"float32"), T_transpose: T.Buffer((T.int64(4), T.int64(16), T.int64(32),
T.int64(16)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ T_transpose_1 = T.alloc_buffer((T.int64(4), T.int64(32),
T.int64(16), T.int64(8)))
+ T_reshape = T.alloc_buffer((T.int64(128), T.int64(16), T.int64(8)))
+ T_transpose_2 = T.alloc_buffer((T.int64(4), T.int64(32),
T.int64(8), T.int64(8)))
+ T_reshape_1 = T.alloc_buffer((T.int64(128), T.int64(8),
T.int64(8)))
+ T_batch_matmul_NT = T.alloc_buffer((T.int64(128), T.int64(16),
T.int64(8)))
+ T_multiply = T.alloc_buffer((T.int64(128), T.int64(16),
T.int64(8)))
+ T_reshape_2 = T.alloc_buffer((T.int64(4), T.int64(32),
T.int64(16), T.int64(8)))
+ T_add = T.alloc_buffer((T.int64(4), T.int64(32), T.int64(16),
T.int64(8)))
+ T_reshape_3 = T.alloc_buffer((T.int64(128), T.int64(16),
T.int64(8)))
+ T_softmax_maxelem = T.alloc_buffer((T.int64(128), T.int64(16)))
+ T_softmax_exp = T.alloc_buffer((T.int64(128), T.int64(16),
T.int64(8)))
+ T_softmax_expsum = T.alloc_buffer((T.int64(128), T.int64(16)))
+ T_softmax_norm = T.alloc_buffer((T.int64(128), T.int64(16),
T.int64(8)))
+ T_transpose_3 = T.alloc_buffer((T.int64(4), T.int64(32),
T.int64(8), T.int64(16)))
+ T_reshape_4 = T.alloc_buffer((T.int64(128), T.int64(8),
T.int64(16)))
+ T_batch_matmul_NN = T.alloc_buffer((T.int64(128), T.int64(16),
T.int64(16)))
+ T_reshape_5 = T.alloc_buffer((T.int64(4), T.int64(32),
T.int64(16), T.int64(16)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32),
T.int64(16), T.int64(8)):
+ with T.block("T_transpose"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder[v_ax0, v_ax2, v_ax1, v_ax3])
+ T.writes(T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_transpose_1[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder[v_ax0, v_ax2, v_ax1, v_ax3]
+ for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_transpose_1[((v_ax2 // T.int64(8) + v_ax1) //
T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) +
v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) %
T.int64(16), v_ax2 % T.int64(8)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2])
+ T_reshape[v_ax0, v_ax1, v_ax2] = T_transpose_1[((v_ax2 //
T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32),
((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 //
T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32),
T.int64(8), T.int64(8)):
+ with T.block("T_transpose_1"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder_1[v_ax0, v_ax2, v_ax1, v_ax3])
+ T.writes(T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_transpose_2[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_1[v_ax0, v_ax2, v_ax1, v_ax3]
+ for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(8)):
+ with T.block("T_reshape_1"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_transpose_2[((v_ax2 // T.int64(8) + v_ax1) //
T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) +
v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) %
T.int64(8), v_ax2 % T.int64(8)])
+ T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2])
+ T_reshape_1[v_ax0, v_ax1, v_ax2] = T_transpose_2[((v_ax2
// T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32),
((v_ax2 // T.int64(8) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 //
T.int64(8) + v_ax1) % T.int64(8), v_ax2 % T.int64(8)]
+ for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(8),
T.int64(8)):
+ with T.block("T_batch_matmul_NT"):
+ v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
+ T.reads(T_reshape[v_b, v_i, v_k], T_reshape_1[v_b, v_j,
v_k])
+ T.writes(T_batch_matmul_NT[v_b, v_i, v_j])
+ T.block_attr({"layout_free_placeholders": [T_reshape_1]})
+ with T.init():
+ T_batch_matmul_NT[v_b, v_i, v_j] = T.float32(0)
+ T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[v_b,
v_i, v_j] + T_reshape[v_b, v_i, v_k] * T_reshape_1[v_b, v_j, v_k]
+ for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_batch_matmul_NT[v_ax0, v_ax1, v_ax2])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+ T_multiply[v_ax0, v_ax1, v_ax2] = T_batch_matmul_NT[v_ax0,
v_ax1, v_ax2] * T.float32(0.10000000000000001)
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32),
T.int64(16), T.int64(8)):
+ with T.block("T_reshape_2"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_multiply[(v_ax0 * T.int64(32) + (v_ax3 //
T.int64(8) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 //
T.int64(8) + v_ax2) % T.int64(16), v_ax3 % T.int64(8)])
+ T.writes(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3] =
T_multiply[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(8) + v_ax2) // T.int64(16)
+ v_ax1) % T.int64(128), (v_ax3 // T.int64(8) + v_ax2) % T.int64(16), v_ax3 %
T.int64(8)]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32),
T.int64(16), T.int64(8)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_reshape_2[v_ax0, v_ax1, v_ax2, v_ax3],
rxplaceholder_3[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_add[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_add[v_ax0, v_ax1, v_ax2, v_ax3] = T_reshape_2[v_ax0,
v_ax1, v_ax2, v_ax3] + rxplaceholder_3[v_ax0, v_ax1, v_ax2, v_ax3]
+ for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+ with T.block("T_reshape_3"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_add[((v_ax2 // T.int64(8) + v_ax1) //
T.int64(16) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(8) +
v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 // T.int64(8) + v_ax1) %
T.int64(16), v_ax2 % T.int64(8)])
+ T.writes(T_reshape_3[v_ax0, v_ax1, v_ax2])
+ T_reshape_3[v_ax0, v_ax1, v_ax2] = T_add[((v_ax2 //
T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(128) // T.int64(32),
((v_ax2 // T.int64(8) + v_ax1) // T.int64(16) + v_ax0) % T.int64(32), (v_ax2 //
T.int64(8) + v_ax1) % T.int64(16), v_ax2 % T.int64(8)]
+ for i0, i1, k in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+ with T.block("T_softmax_maxelem"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(T_reshape_3[v_i0, v_i1, v_k])
+ T.writes(T_softmax_maxelem[v_i0, v_i1])
+ with T.init():
+ T_softmax_maxelem[v_i0, v_i1] =
T.float32(-3.4028234663852886e+38)
+ T_softmax_maxelem[v_i0, v_i1] =
T.max(T_softmax_maxelem[v_i0, v_i1], T_reshape_3[v_i0, v_i1, v_k])
+ for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+ with T.block("T_softmax_exp"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(T_reshape_3[v_i0, v_i1, v_i2],
T_softmax_maxelem[v_i0, v_i1])
+ T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
+ T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(T_reshape_3[v_i0,
v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
+ for i0, i1, k in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+ with T.block("T_softmax_expsum"):
+ v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
+ T.reads(T_softmax_exp[v_i0, v_i1, v_k])
+ T.writes(T_softmax_expsum[v_i0, v_i1])
+ with T.init():
+ T_softmax_expsum[v_i0, v_i1] = T.float32(0)
+ T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0,
v_i1] + T_softmax_exp[v_i0, v_i1, v_k]
+ for i0, i1, i2 in T.grid(T.int64(128), T.int64(16), T.int64(8)):
+ with T.block("T_softmax_norm"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(T_softmax_exp[v_i0, v_i1, v_i2],
T_softmax_expsum[v_i0, v_i1])
+ T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
+ T.block_attr({"axis": 2})
+ T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0,
v_i1, v_i2] / T_softmax_expsum[v_i0, v_i1]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32),
T.int64(8), T.int64(16)):
+ with T.block("T_transpose_2"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(rxplaceholder_2[v_ax0, v_ax2, v_ax1, v_ax3])
+ T.writes(T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_transpose_3[v_ax0, v_ax1, v_ax2, v_ax3] =
rxplaceholder_2[v_ax0, v_ax2, v_ax1, v_ax3]
+ for ax0, ax1, ax2 in T.grid(T.int64(128), T.int64(8), T.int64(16)):
+ with T.block("T_reshape_4"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(T_transpose_3[((v_ax2 // T.int64(16) + v_ax1) //
T.int64(8) + v_ax0) % T.int64(128) // T.int64(32), ((v_ax2 // T.int64(16) +
v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 // T.int64(16) + v_ax1) %
T.int64(8), v_ax2 % T.int64(16)])
+ T.writes(T_reshape_4[v_ax0, v_ax1, v_ax2])
+ T_reshape_4[v_ax0, v_ax1, v_ax2] = T_transpose_3[((v_ax2
// T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(128) // T.int64(32),
((v_ax2 // T.int64(16) + v_ax1) // T.int64(8) + v_ax0) % T.int64(32), (v_ax2 //
T.int64(16) + v_ax1) % T.int64(8), v_ax2 % T.int64(16)]
+ for b, i, j, k in T.grid(T.int64(128), T.int64(16), T.int64(16),
T.int64(8)):
+ with T.block("T_batch_matmul_NN"):
+ v_b, v_i, v_j, v_k = T.axis.remap("SSSR", [b, i, j, k])
+ T.reads(T_softmax_norm[v_b, v_i, v_k], T_reshape_4[v_b,
v_k, v_j])
+ T.writes(T_batch_matmul_NN[v_b, v_i, v_j])
+ T.block_attr({"layout_free_placeholders": [T_reshape_4]})
+ with T.init():
+ T_batch_matmul_NN[v_b, v_i, v_j] = T.float32(0)
+ T_batch_matmul_NN[v_b, v_i, v_j] = T_batch_matmul_NN[v_b,
v_i, v_j] + T_softmax_norm[v_b, v_i, v_k] * T_reshape_4[v_b, v_k, v_j]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(32),
T.int64(16), T.int64(16)):
+ with T.block("T_reshape_5"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 //
T.int64(16) + v_ax2) // T.int64(16) + v_ax1) % T.int64(128), (v_ax3 //
T.int64(16) + v_ax2) % T.int64(16), v_ax3 % T.int64(16)])
+ T.writes(T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape_5[v_ax0, v_ax1, v_ax2, v_ax3] =
T_batch_matmul_NN[(v_ax0 * T.int64(32) + (v_ax3 // T.int64(16) + v_ax2) //
T.int64(16) + v_ax1) % T.int64(128), (v_ax3 // T.int64(16) + v_ax2) %
T.int64(16), v_ax3 % T.int64(16)]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(16),
T.int64(32), T.int64(16)):
+ with T.block("T_transpose_3"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3])
+ T.writes(T_transpose[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_transpose[v_ax0, v_ax1, v_ax2, v_ax3] =
T_reshape_5[v_ax0, v_ax2, v_ax1, v_ax3]
+
+ @R.function
+ def main(q: R.Tensor((4, 16, 32, 8), dtype="float32"), k: R.Tensor((4,
8, 32, 8), dtype="float32"), v: R.Tensor((4, 8, 32, 16), dtype="float32"),
bias: R.Tensor((4, 32, 16, 8), dtype="float32")) -> R.Tensor((4, 16, 32, 16),
dtype="float32"):
+ gv = R.call_tir(Expected.attention_bias, (q, k, v, bias),
out_sinfo=R.Tensor((4, 16, 32, 16), dtype="float32"))
+ return gv
+ # fmt: on
+
+ mod = LegalizeOps()(Attention)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()