This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push: new ccca0f5ecf [Unity][BYOC] Fuse attention pattern with `strided_slice` (#14649) ccca0f5ecf is described below commit ccca0f5ecf75f3bbc38c38d4b559dc8f9ac3fc48 Author: Yaxing Cai <caiyaxing...@gmail.com> AuthorDate: Tue Apr 18 01:23:17 2023 -0700 [Unity][BYOC] Fuse attention pattern with `strided_slice` (#14649) * [Unity][BYOC] Fuse attention pattern with `strided_slice` This PR expands the support for fused stacked attention patterns starting with `strided_slice`. Initially, we only supported the fused stacked attention pattern starting with `split` in #14608. But with the help of #14583, we may have similar patterns starting with `strided_slice` as well. * remove useless code --- python/tvm/relax/backend/contrib/cutlass.py | 12 ++++++-- python/tvm/relax/backend/patterns.py | 23 +++++++++++---- tests/python/relax/test_codegen_cutlass.py | 44 +++++++++++++++++++++++------ 3 files changed, 64 insertions(+), 15 deletions(-) diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py index 4515118f58..06edd9febf 100644 --- a/python/tvm/relax/backend/contrib/cutlass.py +++ b/python/tvm/relax/backend/contrib/cutlass.py @@ -247,11 +247,19 @@ def attention_patterns(): ), ( "cutlass.stacked_attention", - *make_stacked_attention_pattern(), + *make_stacked_attention_pattern(start_op="split"), ), ( "cutlass.stacked_attention", - *make_stacked_attention_pattern(with_bias=True), + *make_stacked_attention_pattern(start_op="split", with_bias=True), ), + ( + "cutlass.stacked_attention", + *make_stacked_attention_pattern(start_op="strided_slice"), + ), + ( + "cutlass.stacked_attention", + *make_stacked_attention_pattern(start_op="strided_slice", with_bias=True), ), ] diff --git a/python/tvm/relax/backend/patterns.py b/python/tvm/relax/backend/patterns.py index 9e34b0c964..6197fe44ca 100644 --- a/python/tvm/relax/backend/patterns.py +++ b/python/tvm/relax/backend/patterns.py
@@ -197,12 +197,15 @@ def make_attention_pattern(with_bias: bool = False): return out, annotations -def make_stacked_attention_pattern(with_bias: bool = False): +def make_stacked_attention_pattern(start_op: str, with_bias: bool = False): """ Create pattern for fused multi head attention with stacked input. Parameters ---------- + start_op: str + The starting op for pattern, i.e. `R.split` or `R.strided_slice`. + with_bias: bool Whether or not to include bias addition @@ -217,13 +220,23 @@ def make_stacked_attention_pattern(with_bias: bool = False): check function and codegen. """ stacked_qkv = wildcard() - qkv_tuple = is_op("relax.split")(stacked_qkv) + if start_op == "split": + qkv_tuple = is_op("relax.split")(stacked_qkv) + query_raw = is_tuple_get_item(qkv_tuple, 0) + key_raw = is_tuple_get_item(qkv_tuple, 1) + value_raw = is_tuple_get_item(qkv_tuple, 2) + elif start_op == "strided_slice": + query_raw = is_op("relax.strided_slice")(stacked_qkv) + key_raw = is_op("relax.strided_slice")(stacked_qkv) + value_raw = is_op("relax.strided_slice")(stacked_qkv) + else: + raise NotImplementedError() query_reshape_list = wildcard() key_reshape_list = wildcard() value_reshape_list = wildcard() - query = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 0), query_reshape_list) - key = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 1), key_reshape_list) - value = is_op("relax.reshape")(is_tuple_get_item(qkv_tuple, 2), value_reshape_list) + query = is_op("relax.reshape")(query_raw, query_reshape_list) + key = is_op("relax.reshape")(key_raw, key_reshape_list) + value = is_op("relax.reshape")(value_raw, value_reshape_list) annotations = { "stacked_qkv": stacked_qkv, "query_reshape_list": query_reshape_list, diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py index 4309627bf0..db8abf34c2 100644 --- a/tests/python/relax/test_codegen_cutlass.py +++ b/tests/python/relax/test_codegen_cutlass.py @@ -660,7 +660,7 @@ def 
get_numpy_stacked_attention_ref(b, s, n, h, h_v, bias_shape, bias_reshape, q return qkv, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v -def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale=None): +def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, op, bias=None, qk_scale=None): dtype = str(qkv.dtype) from tvm.script.ir_builder import IRBuilder @@ -676,10 +676,22 @@ def get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias=None, qk_scale if bias is not None: bias = R.arg("bias", R.Tensor(bias.shape, dtype)) with R.dataflow() as frame: - qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2) - q = R.reshape(qkv_tuple[0], [b, s, n, h]) - k = R.reshape(qkv_tuple[1], [b, s, n, h]) - v = R.reshape(qkv_tuple[2], [b, s, n, h_v]) + if op == "split": + qkv_tuple = R.split(qkv, [n * h, n * h * 2], axis=2) + q = R.reshape(qkv_tuple[0], [b, s, n, h]) + k = R.reshape(qkv_tuple[1], [b, s, n, h]) + v = R.reshape(qkv_tuple[2], [b, s, n, h_v]) + elif op == "strided_slice": + q = R.reshape(R.strided_slice(qkv, [2], [0], [n * h], [1]), [b, s, n, h]) + k = R.reshape( + R.strided_slice(qkv, [2], [n * h], [n * h * 2], [1]), [b, s, n, h] + ) + v = R.reshape( + R.strided_slice(qkv, [2], [n * h * 2], [n * h * 2 + n * h_v], [1]), + [b, s, n, h_v], + ) + else: + raise NotImplementedError() result = R.emit(R.nn.attention(q, k, v, bias, qk_scale)) R.output(result) @@ -700,15 +712,31 @@ def stacked_attention_size(request): return request.param -def test_stacked_attention_offload(stacked_attention_size): +def test_stacked_attention_split_offload(stacked_attention_size): + b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size + qkv, bias, ref = get_numpy_stacked_attention_ref( + b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32" + ) + if scale == "none": + mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias) + else: + mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "split", bias, 
scale) + if bias is None: + out = get_result_with_relax_cutlass_offload(mod, qkv) + else: + out = get_result_with_relax_cutlass_offload(mod, qkv, bias) + tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2) + + +def test_stacked_attention_strided_slice_offload(stacked_attention_size): b, s, n, (h, h_v), bias_shape, bias_reshape, scale = stacked_attention_size qkv, bias, ref = get_numpy_stacked_attention_ref( b, s, n, h, h_v, bias_shape, bias_reshape, scale, "float32" ) if scale == "none": - mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias) + mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias) else: - mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, bias, scale) + mod = get_relax_stacked_attention_module(qkv, b, s, n, h, h_v, "strided_slice", bias, scale) if bias is None: out = get_result_with_relax_cutlass_offload(mod, qkv) else: