This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 76946b4575 Resolving inconsistency between attention/attention_bias (#18038)
76946b4575 is described below

commit 76946b4575828d32c53167984f7fcfe976c76834
Author: Taylor <[email protected]>
AuthorDate: Fri Jun 6 05:00:44 2025 +0800

    Resolving inconsistency between attention/attention_bias (#18038)
    
    * Resolving inconsistency between attention/attention_bias
    
    * reformat
    
    * reduce the length of line
    
    ---------
    
    Co-authored-by: taylor <[email protected]>
---
 python/tvm/relax/op/nn/__init__.py |  1 +
 python/tvm/relax/op/nn/nn.py       | 97 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 98 insertions(+)

diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index 62fa0d53a9..a10f2caeab 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -20,6 +20,7 @@ from .nn import (
     adaptive_avg_pool2d,
     adaptive_avg_pool3d,
     attention,
+    attention_bias,
     attention_var_len,
     avg_pool1d,
     avg_pool2d,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index c6beea3158..a64055ba4f 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1977,6 +1977,103 @@ def attention(
     )  # type: ignore
 
 
+def attention_bias(
+    query: Expr,
+    key: Expr,
+    value: Expr,
+    bias: Optional[Expr] = None,
+    scale: Optional[FloatImm] = None,
+    causal_mask: Optional[str] = None,
+    window_size: Optional[int] = None,
+) -> Expr:
+    r"""Computes fused multi head attention.
+
+    IRModule.script() transforms attention op to attention_bias which is incompatible
+    with TVMScript Parser.
+    The function makes TVMScript's print compatible with TVMScript's parser.
+
+    All input tensors are 4-D tensors with BSNH layout.
+
+    .. math::
+        FMA(Q, K, V) = \text{Softmax}(Q @ K^T) @ V
+
+    .. note::
+        The input tensor is required to have float16 dtype
+
+    Parameters
+    ----------
+    query: relax.Expr
+        The input query to the operator. The layout of the input query should be
+        (batch_size, seq_len, num_head, head_dim).
+
+    key: relax.Expr
+        The input key to the operator. The layout of the input key should be
+        (batch_size, seq_len_kv, num_head, head_dim).
+
+    value: relax.Expr
+        The input value to the operator. The layout of the input value should be
+        (batch_size, seq_len_kv, num_head, head_dim_v).
+
+    bias: Optional[Expr]
+        The optional attention bias to the operator. The layout of the attention bias should be
+        a 4-D tensor ending with seq_len_kv, and broadcastable to
+        (batch_size, num_head, seq_len, seq_len_kv).
+
+    scale: Optional[FloatImm]
+        The scale value to be applied to the attention score, by default 1 / sqrt(head_dim).
+
+    causal_mask: Optional[str]
+        The optional causal mask, either 'TopLeft' or 'BottomRight'.
+        For 'TopLeft', the mask matrix is as `np.tril(*, k=0)`,
+        while for 'BottomRight', the mask matrix is as `np.tril(*, k=abs(seq_len - seq_len_kv))`
+        For example, with seq_len = 4, seq_len_kv = 2,
+        mask for 'TopLeft':
+
+        .. code:: python
+
+            [[1, 0],
+            [1, 1],
+            [1, 1],
+            [1, 1]]
+
+        mask for 'BottomRight':
+
+        .. code:: python
+
+            [[1, 1],
+            [1, 1],
+            [1, 1],
+            [1, 1]]
+
+        with seq_len = 2, seq_len_kv = 4,
+        mask for 'TopLeft':
+
+        .. code:: python
+
+            [[1, 0, 0, 0],
+            [1, 1, 0, 0]]
+
+        mask for 'BottomRight':
+
+        .. code:: python
+
+            [[1, 1, 1, 0],
+            [1, 1, 1, 1]]
+
+    window_size: Optional[int]
+        The size of the window for sliding-window attention.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result. The layout of the output should be
+        (batch_size, seq_len, num_head, head_dim_v).
+    """
+    return _ffi_api.attention(
+        query, key, value, bias, scale, causal_mask, window_size
+    )  # type: ignore
+
+
 def attention_var_len(
     queries: Expr,
     keys: Expr,

Reply via email to