This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 76946b4575 Resolving inconsistency between attention/attention_bias (#18038)
76946b4575 is described below
commit 76946b4575828d32c53167984f7fcfe976c76834
Author: Taylor <[email protected]>
AuthorDate: Fri Jun 6 05:00:44 2025 +0800
Resolving inconsistency between attention/attention_bias (#18038)
* Resolving inconsistency between attention/attention_bias
* reformat
* reduce the length of line
---------
Co-authored-by: taylor <[email protected]>
---
python/tvm/relax/op/nn/__init__.py | 1 +
python/tvm/relax/op/nn/nn.py | 97 ++++++++++++++++++++++++++++++++++++++
2 files changed, 98 insertions(+)
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 62fa0d53a9..a10f2caeab 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -20,6 +20,7 @@ from .nn import (
adaptive_avg_pool2d,
adaptive_avg_pool3d,
attention,
+ attention_bias,
attention_var_len,
avg_pool1d,
avg_pool2d,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index c6beea3158..a64055ba4f 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1977,6 +1977,103 @@ def attention(
) # type: ignore
+def attention_bias(
+ query: Expr,
+ key: Expr,
+ value: Expr,
+ bias: Optional[Expr] = None,
+ scale: Optional[FloatImm] = None,
+ causal_mask: Optional[str] = None,
+ window_size: Optional[int] = None,
+) -> Expr:
+ r"""Computes fused multi-head attention.
+
+ IRModule.script() transforms the attention op into attention_bias, which is
+ incompatible with the TVMScript parser.
+ This function makes TVMScript's printed output compatible with the TVMScript parser.
+
+ All input tensors are 4-D tensors in BSNH layout.
+
+ .. math::
+ FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V
+
+ .. note::
+ The input tensors are required to have float16 dtype.
+
+ Parameters
+ ----------
+ query: relax.Expr
+ The input query to the operator. The layout of the input query should be
+ (batch_size, seq_len, num_head, head_dim).
+
+ key: relax.Expr
+ The input key to the operator. The layout of the input key should be
+ (batch_size, seq_len_kv, num_head, head_dim).
+
+ value: relax.Expr
+ The input value to the operator. The layout of the input value should be
+ (batch_size, seq_len_kv, num_head, head_dim_v).
+
+ bias: Optional[Expr]
+ The optional attention bias to the operator. The attention bias should be
+ a 4-D tensor ending with seq_len_kv, broadcastable to
+ (batch_size, num_head, seq_len, seq_len_kv).
+
+ scale: Optional[FloatImm]
+ The scale value to be applied to the attention score, by default 1 / sqrt(head_dim).
+
+ causal_mask: Optional[str]
+ The optional causal mask, either 'TopLeft' or 'BottomRight'.
+ For 'TopLeft', the mask matrix is `np.tril(*, k=0)`,
+ while for 'BottomRight', the mask matrix is `np.tril(*, k=abs(seq_len - seq_len_kv))`.
+ For example, with seq_len = 4, seq_len_kv = 2,
+ mask for 'TopLeft':
+
+ .. code:: python
+
+ [[1, 0],
+ [1, 1],
+ [1, 1],
+ [1, 1]]
+
+ mask for 'BottomRight':
+
+ .. code:: python
+
+ [[1, 1],
+ [1, 1],
+ [1, 1],
+ [1, 1]]
+
+ And with seq_len = 2, seq_len_kv = 4,
+ mask for 'TopLeft':
+
+ .. code:: python
+
+ [[1, 0, 0, 0],
+ [1, 1, 0, 0]]
+
+ mask for 'BottomRight':
+
+ .. code:: python
+
+ [[1, 1, 1, 0],
+ [1, 1, 1, 1]]
+
+ window_size: Optional[int]
+ The size of the window for sliding-window attention.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result. The layout of the output should be
+ (batch_size, seq_len, num_head, head_dim_v).
+ """
+ return _ffi_api.attention(
+ query, key, value, bias, scale, causal_mask, window_size
+ ) # type: ignore
+
+
def attention_var_len(
queries: Expr,
keys: Expr,
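A minimal sketch of how the new wrapper might be exercised through the Relax
BlockBuilder; the shapes, variable names, and the 'TopLeft' mask choice are
illustrative assumptions, not taken from the commit. Inputs follow the
documented BSNH layout and the required float16 dtype:

    from tvm import relax
    from tvm.script import relax as R

    batch, seq_len, seq_len_kv, num_head, head_dim = 1, 4, 2, 8, 64

    # Query/key/value in BSNH layout with the required float16 dtype.
    q = relax.Var("q", R.Tensor((batch, seq_len, num_head, head_dim), "float16"))
    k = relax.Var("k", R.Tensor((batch, seq_len_kv, num_head, head_dim), "float16"))
    v = relax.Var("v", R.Tensor((batch, seq_len_kv, num_head, head_dim), "float16"))

    bb = relax.BlockBuilder()
    with bb.function("main", [q, k, v]):
        # attention_bias maps to the same underlying op as attention,
        # per the _ffi_api.attention call in the diff above.
        out = bb.emit(relax.op.nn.attention_bias(q, k, v, causal_mask="TopLeft"))
        bb.emit_func_output(out)

    # With this commit, the printed script can round-trip through the parser.
    mod = bb.get()
    print(mod.script())

The two causal-mask variants documented in the docstring can also be
reproduced with numpy, under the same seq_len = 4, seq_len_kv = 2 assumption:

    import numpy as np

    seq_len, seq_len_kv = 4, 2
    ones = np.ones((seq_len, seq_len_kv))
    top_left = np.tril(ones, k=0)                               # 'TopLeft'
    bottom_right = np.tril(ones, k=abs(seq_len - seq_len_kv))   # 'BottomRight'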