This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ef39f37f49 [Unity][BYOC] Support variable-length attention by flash attention (#15959)
ef39f37f49 is described below

commit ef39f37f49c8dbb885d56a92e97427d3e4ec10c4
Author: masahi <masahi...@gmail.com>
AuthorDate: Thu Oct 26 07:43:25 2023 +0900

    [Unity][BYOC] Support variable-length attention by flash attention (#15959)
    
    * works
    
    * add test
    
    * fix tests
---
 3rdparty/libflash_attn                            |   2 +-
 python/tvm/contrib/cutlass/attention_operation.py |  56 +++++++
 python/tvm/contrib/cutlass/gen_tensor_op.py       |  16 +-
 tests/python/relax/test_codegen_cutlass.py        | 190 ++++++++++++++++------
 4 files changed, 206 insertions(+), 58 deletions(-)

diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn
index c1d793ad93..55d3603f74 160000
--- a/3rdparty/libflash_attn
+++ b/3rdparty/libflash_attn
@@ -1 +1 @@
-Subproject commit c1d793ad939c8ec3cec351db84bc80808e4d34c3
+Subproject commit 55d3603f741eb68e82640ff55ccea4b17dd8053e
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 5579819001..7084a105c8 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -279,3 +279,59 @@ def instantiate_flash_attention_template(attrs):
         return substitute_template(template_stacked, attrs)
 
     return substitute_template(template, attrs)
+
+
+def instantiate_flash_attention_var_len_template(attrs):
+    """Return host code for flash attention with variable sequence lengths."""
+
+    template = """
+    int _max_seqlen_q, _max_seqlen_k;
+    cudaMemcpy(&_max_seqlen_q, (int32_t*)${max_seqlen_q}->data, sizeof(int32_t),
+               cudaMemcpyDeviceToHost);
+    cudaMemcpy(&_max_seqlen_k, (int32_t*)${max_seqlen_k}->data, sizeof(int32_t),
+               cudaMemcpyDeviceToHost);
+
+    int batch_size = ${seqstart_q}->shape[0] - 1;
+
+    int q_head_stride = ${head_dim};
+    int k_head_stride = ${head_dim};
+    int v_head_stride = ${head_dim};
+    int o_head_stride = ${head_dim};
+    int q_row_stride = q_head_stride * ${num_q_heads};
+    int k_row_stride = k_head_stride * ${num_kv_heads};
+    int v_row_stride = v_head_stride * ${num_kv_heads};
+    int o_row_stride = o_head_stride * ${num_q_heads};
+
+    auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+    ICHECK(func != nullptr);
+    cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+    flash_attn::flash_attention_var_len_forward(
+                            static_cast<const cutlass::half_t*>(${query}->data),
+                           static_cast<const cutlass::half_t*>(${key}->data),
+                           static_cast<const cutlass::half_t*>(${value}->data),
+                            static_cast<const int*>(${seqstart_q}->data),
+                            static_cast<const int*>(${seqstart_k}->data),
+                           static_cast<cutlass::half_t*>(out0->data),
+                           batch_size,
+                           _max_seqlen_q,
+                           _max_seqlen_k,
+                           ${num_q_heads},
+                           ${num_kv_heads},
+                           ${head_dim},
+                           q_head_stride,
+                           k_head_stride,
+                           v_head_stride,
+                           o_head_stride,
+                           q_row_stride,
+                           k_row_stride,
+                           v_row_stride,
+                           o_row_stride,
+                           ${scale},
+                           ${is_causal},
+                            ${is_causal} ? _max_seqlen_k : -1,
+                            ${window_size_right},
+                           stream);
+    """
+
+    return substitute_template(template, attrs)
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index e86a02df60..d42791d71b 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -32,6 +32,7 @@ from . import _ffi_api as ffi
 from .attention_operation import (
     instantiate_attention_template,
     instantiate_flash_attention_template,
+    instantiate_flash_attention_var_len_template,
 )
 from .conv2d_operation import instantiate_conv2d_template
 from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
@@ -778,7 +779,6 @@ def instantiate_template(func_name, annotations, func_args):
             )
             # Flash v2 is currently not supported for sm < 80
             and int(annotations["arch"]) >= 80
-            and not is_var_len
         )
 
         if "window_size" in annotations:
@@ -789,15 +789,23 @@ def instantiate_template(func_name, annotations, func_args):
             attrs["window_size_left"] = int(annotations["window_size"]) - 1
             attrs["window_size_right"] = 0
         else:
-            attrs["window_size_left"] = -1
-            attrs["window_size_right"] = -1
+            if int(annotations["custom_mask_type"]) == 2:
+                attrs["window_size_left"] = attrs["num_keys"]
+                attrs["window_size_right"] = 0
+            else:
+                attrs["window_size_left"] = -1
+                attrs["window_size_right"] = -1
 
         if use_flash:
             headers.append("flash.h")
             attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
             attrs["num_q_heads"] = annotations["num_q_heads"]
             attrs["num_kv_heads"] = annotations["num_kv_heads"]
-            code = instantiate_flash_attention_template(attrs)
+
+            if is_var_len:
+                code = instantiate_flash_attention_var_len_template(attrs)
+            else:
+                code = instantiate_flash_attention_template(attrs)
         else:
             headers.append("kernel_forward.h")
 
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 151e05e9b6..f07a0dfcbb 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -611,13 +611,35 @@ def get_relax_attention_module(
     return tvm.IRModule({"main": func})
 
 
-@memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
 def get_numpy_attention_ref(
-    b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal, dtype, window_size=None
+    b,
+    s,
+    s_kv,
+    n,
+    h,
+    h_v,
+    bias_shape,
+    qk_scale,
+    causal,
+    dtype,
+    window_size=None,
+    num_kv_head=None,
 ):
+    if num_kv_head is None:
+        num_kv_head = n
+
     q = np.random.randn(b, s, n, h).astype(dtype)
-    k = np.random.randn(b, s_kv, n, h).astype(dtype)
-    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    k_orig = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype)
+    v_orig = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype)
+
+    if num_kv_head is None:
+        k = k_orig
+        v = v_orig
+    else:
+        factor = n // num_kv_head
+        k = np.repeat(k_orig, factor, axis=2)
+        v = np.repeat(v_orig, factor, axis=2)
+
     qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
     kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
     if not qk_scale == "none":
@@ -655,7 +677,7 @@ def get_numpy_attention_ref(
 
     vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
     ref = attn @ vt  # b, n, s, h_v
-    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+    return q, k_orig, v_orig, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
 
 
 def test_attention_offload(attention_size, attention_dtype):
@@ -1191,6 +1213,7 @@ def test_attention_rewrite_fp16():
             v: R.Tensor((4, 8, 32, 16), dtype="float16"),
             bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
         ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+            R.func_attr({"num_input": 4})
             with R.dataflow():
                 lv = R.permute_dims(q, axes=[0, 2, 1, 3])
                 lv1 = R.reshape(lv, R.shape([128, 16, 8]))
@@ -1262,7 +1285,8 @@ def test_attention_rewrite_fp16():
             k: R.Tensor((4, 8, 32, 8), dtype="float16"),
             v: R.Tensor((4, 8, 32, 16), dtype="float16"),
             bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
-        ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+        ) -> R.Tensor((4, 16, 32, 16), dtype="float32"):
+            R.func_attr({"num_input": 4})
             cls = Expected
             with R.dataflow():
                 lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
@@ -2016,49 +2040,10 @@ def test_attention_rewrite_multi_query():
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
-def test_batched_var_len_attention():
+def _test_batched_var_len_attention(mod, seq_lens, num_head, num_kv_head, head_size):
     if not tvm.get_global_func("tvm.contrib.thrust.sum_scan", True):
         return
 
-    @I.ir_module
-    class Module:
-        @R.function
-        def main(
-            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
-            keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
-            values: R.Tensor(("num_tokens", 4096), dtype="float16"),
-            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
-        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
-            cls = Module
-            num_tokens = T.int64()
-            num_seq = T.int64()
-
-            with R.dataflow():
-                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
-                # https://github.com/apache/tvm/issues/15851
-                cumsum = R.call_dps_packed(
-                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
-                )
-                max_seqlen_q = R.max(seq_lens)
-                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
-                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
-                k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
-                v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
-                attn_out = R.nn.attention_var_len(
-                    q,
-                    k,
-                    v,
-                    seqstart_q,
-                    max_seqlen_q,
-                    causal_mask="BottomRight",
-                )
-                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
-                R.output(out)
-            return out
-
-    seq_lens = [5, 3, 8]
-    num_head = 128
-    head_size = 32
     hidden_size = num_head * head_size
 
     batched_queries = []
@@ -2068,11 +2053,21 @@ def test_batched_var_len_attention():
 
     for s in seq_lens:
         q, k, v, _, ref = get_numpy_attention_ref(
-            1, s, s, num_head, head_size, head_size, "none", "none", "BottomRight", "float16"
+            1,
+            s,
+            s,
+            num_head,
+            head_size,
+            head_size,
+            "none",
+            "none",
+            "BottomRight",
+            "float16",
+            num_kv_head=num_kv_head,
         )
         batched_queries.append(np.reshape(q, [-1, hidden_size]))
-        batched_keys.append(np.reshape(k, [-1, hidden_size]))
-        batched_values.append(np.reshape(v, [-1, hidden_size]))
+        batched_keys.append(np.reshape(k, [-1, num_kv_head * head_size]))
+        batched_values.append(np.reshape(v, [-1, num_kv_head * head_size]))
         batched_refs.append(np.reshape(ref, [-1, hidden_size]))
 
     batched_queries = np.vstack(batched_queries)
@@ -2080,7 +2075,7 @@ def test_batched_var_len_attention():
     batched_values = np.vstack(batched_values)
     ref = np.vstack(batched_refs)
 
-    mod = partition_for_cutlass(Module)
+    mod = partition_for_cutlass(mod)
     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
     mod = codegen_pass(mod)
 
@@ -2099,8 +2094,6 @@ def test_batched_var_len_attention():
         "cuda",
     )
 
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-
     ############# xformer reference for verification #############
 
     # attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
@@ -2115,7 +2108,98 @@ def test_batched_var_len_attention():
     # ).cpu().numpy()[0]
     # out = np.reshape(out, [-1, hidden_size])
 
-    # tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+def test_batched_var_len_attention():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            values: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
+        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
+            R.func_attr({"num_input": 4})
+            cls = Module
+            num_tokens = T.int64()
+            num_seq = T.int64()
+
+            with R.dataflow():
+                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
+                # https://github.com/apache/tvm/issues/15851
+                cumsum = R.call_dps_packed(
+                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
+                )
+                max_seqlen_q = R.max(seq_lens)
+                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
+                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
+                k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
+                v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
+                attn_out = R.nn.attention_var_len(
+                    q,
+                    k,
+                    v,
+                    seqstart_q,
+                    max_seqlen_q,
+                    causal_mask="BottomRight",
+                )
+                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
+                R.output(out)
+            return out
+
+    seq_lens = [5, 3, 8]
+    num_head = 128
+    head_size = 32
+
+    _test_batched_var_len_attention(Module, seq_lens, num_head, num_head, head_size)
+
+
+def test_batched_var_len_multi_query_attention():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            keys: R.Tensor(("num_tokens", 512), dtype="float16"),
+            values: R.Tensor(("num_tokens", 512), dtype="float16"),
+            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
+        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
+            R.func_attr({"num_input": 4})
+            cls = Module
+            num_tokens = T.int64()
+            num_seq = T.int64()
+
+            with R.dataflow():
+                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
+                # https://github.com/apache/tvm/issues/15851
+                cumsum = R.call_dps_packed(
+                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
+                )
+                max_seqlen_q = R.max(seq_lens)
+                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
+                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
+                k = R.reshape(keys, R.shape([1, num_tokens, 16, 32]))
+                v = R.reshape(values, R.shape([1, num_tokens, 16, 32]))
+                attn_out = R.nn.attention_var_len(
+                    q,
+                    k,
+                    v,
+                    seqstart_q,
+                    max_seqlen_q,
+                    causal_mask="BottomRight",
+                )
+                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
+                R.output(out)
+            return out
+
+    seq_lens = [5, 3, 8]
+    num_head = 128
+    num_kv_head = 16
+    head_size = 32
+
+    _test_batched_var_len_attention(Module, seq_lens, num_head, num_kv_head, head_size)
 
 
 def test_sliding_window():

Reply via email to