This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

The following commit(s) were added to refs/heads/unity by this push:
     new ef39f37f49 [Unity][BYOC] Support variable-length attention by flash attention (#15959)
ef39f37f49 is described below

commit ef39f37f49c8dbb885d56a92e97427d3e4ec10c4
Author: masahi <masahi...@gmail.com>
AuthorDate: Thu Oct 26 07:43:25 2023 +0900

    [Unity][BYOC] Support variable-length attention by flash attention (#15959)

    * works

    * add test

    * fix tests
---
 3rdparty/libflash_attn                            |   2 +-
 python/tvm/contrib/cutlass/attention_operation.py |  56 +++++++
 python/tvm/contrib/cutlass/gen_tensor_op.py       |  16 +-
 tests/python/relax/test_codegen_cutlass.py        | 190 ++++++++++++++++------
 4 files changed, 206 insertions(+), 58 deletions(-)

diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn
index c1d793ad93..55d3603f74 160000
--- a/3rdparty/libflash_attn
+++ b/3rdparty/libflash_attn
@@ -1 +1 @@
-Subproject commit c1d793ad939c8ec3cec351db84bc80808e4d34c3
+Subproject commit 55d3603f741eb68e82640ff55ccea4b17dd8053e

diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 5579819001..7084a105c8 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -279,3 +279,59 @@ def instantiate_flash_attention_template(attrs):
         return substitute_template(template_stacked, attrs)
 
     return substitute_template(template, attrs)
+
+
+def instantiate_flash_attention_var_len_template(attrs):
+    """Return host code for flash attention with variable sequence lengths."""
+
+    template = """
+    int _max_seqlen_q, _max_seqlen_k;
+    cudaMemcpy(&_max_seqlen_q, (int32_t*)${max_seqlen_q}->data, sizeof(int32_t),
+               cudaMemcpyDeviceToHost);
+    cudaMemcpy(&_max_seqlen_k, (int32_t*)${max_seqlen_k}->data, sizeof(int32_t),
+               cudaMemcpyDeviceToHost);
+
+    int batch_size = ${seqstart_q}->shape[0] - 1;
+
+    int q_head_stride = ${head_dim};
+    int k_head_stride = ${head_dim};
+    int v_head_stride = ${head_dim};
+    int o_head_stride = ${head_dim};
+    int q_row_stride = q_head_stride * ${num_q_heads};
+    int k_row_stride = k_head_stride * ${num_kv_heads};
+    int v_row_stride = v_head_stride * ${num_kv_heads};
+    int o_row_stride = o_head_stride * ${num_q_heads};
+
+    auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+    ICHECK(func != nullptr);
+    cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+    flash_attn::flash_attention_var_len_forward(
+        static_cast<const cutlass::half_t*>(${query}->data),
+        static_cast<const cutlass::half_t*>(${key}->data),
+        static_cast<const cutlass::half_t*>(${value}->data),
+        static_cast<const int*>(${seqstart_q}->data),
+        static_cast<const int*>(${seqstart_k}->data),
+        static_cast<cutlass::half_t*>(out0->data),
+        batch_size,
+        _max_seqlen_q,
+        _max_seqlen_k,
+        ${num_q_heads},
+        ${num_kv_heads},
+        ${head_dim},
+        q_head_stride,
+        k_head_stride,
+        v_head_stride,
+        o_head_stride,
+        q_row_stride,
+        k_row_stride,
+        v_row_stride,
+        o_row_stride,
+        ${scale},
+        ${is_causal},
+        ${is_causal} ? _max_seqlen_k : -1,
+        ${window_size_right},
+        stream);
+    """
+
+    return substitute_template(template, attrs)
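The ${...} placeholders in the template above are filled in from the attributes collected in gen_tensor_op.py (next diff). A minimal, self-contained sketch of that substitution step, using a regex-based stand-in for substitute_template; the buffer name "inp3" is hypothetical:

import re

def substitute(template, attrs):
    # Same idea as the substitute_template helper in the cutlass contrib code:
    # replace each ${name} placeholder with str(attrs["name"]).
    return re.sub(r"\$\{(\w+)\}", lambda m: str(attrs[m.group(1)]), template)

snippet = "int batch_size = ${seqstart_q}->shape[0] - 1;  // ${num_q_heads} query heads"
print(substitute(snippet, {"seqstart_q": "inp3", "num_q_heads": 128}))
# -> int batch_size = inp3->shape[0] - 1;  // 128 query heads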
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index e86a02df60..d42791d71b 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -32,6 +32,7 @@ from . import _ffi_api as ffi
 from .attention_operation import (
     instantiate_attention_template,
     instantiate_flash_attention_template,
+    instantiate_flash_attention_var_len_template,
 )
 from .conv2d_operation import instantiate_conv2d_template
 from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
@@ -778,7 +779,6 @@ def instantiate_template(func_name, annotations, func_args):
         )
         # Flash v2 is currently not supported for sm < 80
         and int(annotations["arch"]) >= 80
-        and not is_var_len
     )
 
     if "window_size" in annotations:
@@ -789,15 +789,23 @@
             attrs["window_size_left"] = int(annotations["window_size"]) - 1
             attrs["window_size_right"] = 0
         else:
-            attrs["window_size_left"] = -1
-            attrs["window_size_right"] = -1
+            if int(annotations["custom_mask_type"]) == 2:
+                attrs["window_size_left"] = attrs["num_keys"]
+                attrs["window_size_right"] = 0
+            else:
+                attrs["window_size_left"] = -1
+                attrs["window_size_right"] = -1
 
         if use_flash:
             headers.append("flash.h")
             attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
             attrs["num_q_heads"] = annotations["num_q_heads"]
             attrs["num_kv_heads"] = annotations["num_kv_heads"]
-            code = instantiate_flash_attention_template(attrs)
+
+            if is_var_len:
+                code = instantiate_flash_attention_var_len_template(attrs)
+            else:
+                code = instantiate_flash_attention_template(attrs)
         else:
             headers.append("kernel_forward.h")

diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 151e05e9b6..f07a0dfcbb 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -611,13 +611,35 @@ def get_relax_attention_module(
     return tvm.IRModule({"main": func})
 
 
-@memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
 def get_numpy_attention_ref(
-    b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal, dtype, window_size=None
+    b,
+    s,
+    s_kv,
+    n,
+    h,
+    h_v,
+    bias_shape,
+    qk_scale,
+    causal,
+    dtype,
+    window_size=None,
+    num_kv_head=None,
 ):
+    if num_kv_head is None:
+        num_kv_head = n
+
     q = np.random.randn(b, s, n, h).astype(dtype)
-    k = np.random.randn(b, s_kv, n, h).astype(dtype)
-    v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+    k_orig = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype)
+    v_orig = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype)
+
+    if num_kv_head is None:
+        k = k_orig
+        v = v_orig
+    else:
+        factor = n // num_kv_head
+        k = np.repeat(k_orig, factor, axis=2)
+        v = np.repeat(v_orig, factor, axis=2)
+
     qt = q.transpose(0, 2, 1, 3)  # b, n, s, h
     kt = k.transpose(0, 2, 3, 1)  # b, n, h, s_kv
     if not qk_scale == "none":
@@ -655,7 +677,7 @@
     vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
     ref = attn @ vt  # b, n, s, h_v
 
-    return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
+    return q, k_orig, v_orig, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
 
 
 def test_attention_offload(attention_size, attention_dtype):
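The num_kv_head parameter added to get_numpy_attention_ref above gives a grouped-query reference: each of the num_kv_head K/V heads is shared by n // num_kv_head query heads, emulated with np.repeat. A self-contained NumPy sketch of the same computation, using toy sizes and assuming the default 1/sqrt(h) scaling:

import numpy as np

# Toy sizes, not the ones used in the tests.
b, s, n, num_kv_head, h = 1, 5, 8, 2, 16
q = np.random.randn(b, s, n, h).astype("float32")
k = np.random.randn(b, s, num_kv_head, h).astype("float32")
v = np.random.randn(b, s, num_kv_head, h).astype("float32")

# Each K/V head serves n // num_kv_head query heads.
factor = n // num_kv_head
k_full = np.repeat(k, factor, axis=2)  # (b, s, n, h)
v_full = np.repeat(v, factor, axis=2)

score = q.transpose(0, 2, 1, 3) @ k_full.transpose(0, 2, 3, 1) / np.sqrt(h)
attn = np.exp(score - score.max(axis=-1, keepdims=True))
attn /= attn.sum(axis=-1, keepdims=True)
out = (attn @ v_full.transpose(0, 2, 1, 3)).transpose(0, 2, 1, 3)  # (b, s, n, h)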
@@ -1191,6 +1213,7 @@ def test_attention_rewrite_fp16():
                 v: R.Tensor((4, 8, 32, 16), dtype="float16"),
                 bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
             ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+                R.func_attr({"num_input": 4})
                 with R.dataflow():
                     lv = R.permute_dims(q, axes=[0, 2, 1, 3])
                     lv1 = R.reshape(lv, R.shape([128, 16, 8]))
@@ -1262,7 +1285,8 @@ def test_attention_rewrite_fp16():
                 k: R.Tensor((4, 8, 32, 8), dtype="float16"),
                 v: R.Tensor((4, 8, 32, 16), dtype="float16"),
                 bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
-            ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+            ) -> R.Tensor((4, 16, 32, 16), dtype="float32"):
+                R.func_attr({"num_input": 4})
                 cls = Expected
                 with R.dataflow():
                     lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
@@ -2016,49 +2040,10 @@ def test_attention_rewrite_multi_query():
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
-def test_batched_var_len_attention():
+def _test_batched_var_len_attention(mod, seq_lens, num_head, num_kv_head, head_size):
     if not tvm.get_global_func("tvm.contrib.thrust.sum_scan", True):
         return
 
-    @I.ir_module
-    class Module:
-        @R.function
-        def main(
-            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
-            keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
-            values: R.Tensor(("num_tokens", 4096), dtype="float16"),
-            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
-        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
-            cls = Module
-            num_tokens = T.int64()
-            num_seq = T.int64()
-
-            with R.dataflow():
-                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
-                # https://github.com/apache/tvm/issues/15851
-                cumsum = R.call_dps_packed(
-                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
-                )
-                max_seqlen_q = R.max(seq_lens)
-                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
-                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
-                k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
-                v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
-                attn_out = R.nn.attention_var_len(
-                    q,
-                    k,
-                    v,
-                    seqstart_q,
-                    max_seqlen_q,
-                    causal_mask="BottomRight",
-                )
-                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
-                R.output(out)
-            return out
-
-    seq_lens = [5, 3, 8]
-    num_head = 128
-    head_size = 32
     hidden_size = num_head * head_size
 
     batched_queries = []
@@ -2068,11 +2053,21 @@
 
     for s in seq_lens:
         q, k, v, _, ref = get_numpy_attention_ref(
-            1, s, s, num_head, head_size, head_size, "none", "none", "BottomRight", "float16"
+            1,
+            s,
+            s,
+            num_head,
+            head_size,
+            head_size,
+            "none",
+            "none",
+            "BottomRight",
+            "float16",
+            num_kv_head=num_kv_head,
         )
         batched_queries.append(np.reshape(q, [-1, hidden_size]))
-        batched_keys.append(np.reshape(k, [-1, hidden_size]))
-        batched_values.append(np.reshape(v, [-1, hidden_size]))
+        batched_keys.append(np.reshape(k, [-1, num_kv_head * head_size]))
+        batched_values.append(np.reshape(v, [-1, num_kv_head * head_size]))
        batched_refs.append(np.reshape(ref, [-1, hidden_size]))
 
     batched_queries = np.vstack(batched_queries)
@@ -2080,7 +2075,7 @@
     batched_values = np.vstack(batched_values)
     ref = np.vstack(batched_refs)
 
-    mod = partition_for_cutlass(Module)
+    mod = partition_for_cutlass(mod)
     codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
     mod = codegen_pass(mod)
 
@@ -2099,8 +2094,6 @@
         "cuda",
     )
 
-    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-
     ############# xformer reference for verification #############
 
     # attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
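The _test_batched_var_len_attention helper above hands the kernel one packed (num_tokens, hidden) array per input; sequence boundaries come from seqstart offsets, i.e. a cumulative sum of seq_lens with a leading zero, which is what the thrust sum_scan workaround computes on device. A NumPy sketch of that packing, assuming the same toy sequence lengths:

import numpy as np

seq_lens = [5, 3, 8]
hidden_size = 128 * 32

# Pack variable-length sequences into one (num_tokens, hidden_size) array.
queries = [np.random.randn(s, hidden_size).astype("float16") for s in seq_lens]
packed = np.vstack(queries)  # (16, 4096)

# Prefix sum with a leading zero gives the row range of each sequence.
seqstart = np.concatenate([[0], np.cumsum(seq_lens)]).astype("int32")  # [0, 5, 8, 16]
for i, s in enumerate(seq_lens):
    assert packed[seqstart[i]:seqstart[i + 1]].shape == (s, hidden_size)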
@@ -2115,7 +2108,98 @@
     # ).cpu().numpy()[0]
     # out = np.reshape(out, [-1, hidden_size])
 
-    # tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+def test_batched_var_len_attention():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            values: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
+        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
+            R.func_attr({"num_input": 4})
+            cls = Module
+            num_tokens = T.int64()
+            num_seq = T.int64()
+
+            with R.dataflow():
+                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
+                # https://github.com/apache/tvm/issues/15851
+                cumsum = R.call_dps_packed(
+                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
+                )
+                max_seqlen_q = R.max(seq_lens)
+                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
+                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
+                k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
+                v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
+                attn_out = R.nn.attention_var_len(
+                    q,
+                    k,
+                    v,
+                    seqstart_q,
+                    max_seqlen_q,
+                    causal_mask="BottomRight",
+                )
+                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
+                R.output(out)
+            return out
+
+    seq_lens = [5, 3, 8]
+    num_head = 128
+    head_size = 32
+
+    _test_batched_var_len_attention(Module, seq_lens, num_head, num_head, head_size)
+
+
+def test_batched_var_len_multi_query_attention():
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
+            keys: R.Tensor(("num_tokens", 512), dtype="float16"),
+            values: R.Tensor(("num_tokens", 512), dtype="float16"),
+            seq_lens: R.Tensor(("num_seq",), dtype="int32"),
+        ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
+            R.func_attr({"num_input": 4})
+            cls = Module
+            num_tokens = T.int64()
+            num_seq = T.int64()
+
+            with R.dataflow():
+                # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
+                # https://github.com/apache/tvm/issues/15851
+                cumsum = R.call_dps_packed(
+                    "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info
+                )
+                max_seqlen_q = R.max(seq_lens)
+                seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
+                q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
+                k = R.reshape(keys, R.shape([1, num_tokens, 16, 32]))
+                v = R.reshape(values, R.shape([1, num_tokens, 16, 32]))
+                attn_out = R.nn.attention_var_len(
+                    q,
+                    k,
+                    v,
+                    seqstart_q,
+                    max_seqlen_q,
+                    causal_mask="BottomRight",
+                )
+                out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
+                R.output(out)
+            return out
+
+    seq_lens = [5, 3, 8]
+    num_head = 128
+    num_kv_head = 16
+    head_size = 32
+
+    _test_batched_var_len_attention(Module, seq_lens, num_head, num_kv_head, head_size)
 
 
 def test_sliding_window():
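As the commented-out xformers reference (BlockDiagonalCausalMask) above suggests, stacking per-sequence causal outputs is equivalent to attending over the packed tokens under a block-diagonal causal mask. A NumPy sketch of that mask, assuming equal query and key lengths per sequence so that BottomRight alignment coincides with ordinary causal masking:

import numpy as np

seq_lens = [5, 3, 8]
total = sum(seq_lens)

# Token i may attend to token j only if both fall in the same sequence and
# j <= i; everything else gets -inf before the softmax.
mask = np.full((total, total), -np.inf, dtype="float32")
start = 0
for s in seq_lens:
    block = np.triu(np.full((s, s), -np.inf, dtype="float32"), k=1)
    mask[start:start + s, start:start + s] = block
    start += s
# Adding `mask` to the packed attention scores reproduces the per-sequence
# causal references that the test stacks with np.vstack.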