This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git

The following commit(s) were added to refs/heads/main by this push:
     new 6333d86105  [KVCache] Support mode "None" for Rotary Embedding (#16580)
6333d86105 is described below

commit 6333d86105e28435070976058f7830a15a751fa3
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Fri Feb 16 07:54:18 2024 -0500

    [KVCache] Support mode "None" for Rotary Embedding (#16580)

    This PR supports a "None" Rotary Embedding mode in PagedKVCache.
    When the mode is None, the rotary embedding will not be applied
    when computing attention.
---
 src/runtime/relax_vm/paged_kv_cache.cc             |  6 ++-
 ..._builtin_paged_attention_kv_cache_flashinfer.py | 45 +++++++++++++++-------
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 43 ++++++++++++++++-----
 3 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 7417d90e02..d5ddef7527 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -153,12 +153,14 @@ struct Sequence {
 /*!
  * \brief The rotary embedding mode adopted by the paged KV cache
  * when computing attention.
+ * "None" means RoPE is never applied to q and k.
  * "Normal" means RoPE is computed in a standalone kernel.
  * "Inline" means RoPE is computed on-the-fly in attention kernels.
  */
 enum class RoPEMode : int {
-  kNormal = 0,
-  kInline = 1,
+  kNone = 0,
+  kNormal = 1,
+  kInline = 2,
 };
 
 /*!
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 8a40f43ab7..fccef312c1 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import enum
 from typing import Dict, List, Tuple, Union
 
 import numpy as np
@@ -322,7 +323,19 @@ def create_kv_cache(rope_mode):
     return cache
 
 
-@pytest.fixture(params=[0, 1])
+class RopeMode(enum.IntEnum):
+    """The RoPE mode of the Paged KV cache.
+    If it is none, the KV cache will not apply RoPE to q and k.
+    If it is normal, RoPE will be applied to k before adding k to the cache.
+    Otherwise, RoPE will be applied to q/k in the attention kernel on-the-fly.
+ """ + + NONE = 0 + NORMAL = 1 + INLINE = 2 + + +@pytest.fixture(params=[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]) def kv_cache_and_rope_mode(request): set_global_func() return create_kv_cache(request.param), request.param @@ -361,7 +374,7 @@ def f_apply_rotary(x, offset, scale, theta): def apply_attention( kv_cache, - rope_mode: int, + rope_mode: RopeMode, batch: List[Tuple[Union[int, Tuple[int, int]], int]], cached_k: Dict[int, np.ndarray], cached_v: Dict[int, np.ndarray], @@ -406,10 +419,12 @@ def apply_attention( cached_k[seq_id], np.stack( [ - new_k[l] - if rope_mode == 1 - else f_apply_rotary( - new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta + ( + new_k[l] + if rope_mode != RopeMode.NORMAL + else f_apply_rotary( + new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta + ) ) for l in range(num_layers) ], @@ -445,15 +460,19 @@ def apply_attention( assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length rope_offset = cached_k[seq_id].shape[1] - append_length - q_seq = f_apply_rotary( - q_array[i][layer_id], - rope_offset, - rope_scale, - rope_theta, + q_seq = ( + q_array[i][layer_id] + if rope_mode == RopeMode.NONE + else f_apply_rotary( + q_array[i][layer_id], + rope_offset, + rope_scale, + rope_theta, + ) ).transpose(1, 0, 2) k_seq = ( cached_k[seq_id][layer_id] - if rope_mode == 0 + if rope_mode != RopeMode.INLINE else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta) ).transpose(1, 2, 0) v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) @@ -586,7 +605,7 @@ def test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode, fuse_qkv): if __name__ == "__main__": set_global_func() - for rope_mode in [0, 1]: + for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]: cache = create_kv_cache(rope_mode) for fuse_qkv in [False, True]: test_paged_attention_kv_cache_prefill_and_decode((cache, rope_mode), fuse_qkv) diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index dc4d4082f1..8bd9da3bbb 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import enum import itertools import math from typing import Dict, List, Tuple, Union @@ -140,7 +141,25 @@ def create_kv_cache(head_dim, dtype, rope_mode): return cache -@pytest.fixture(params=itertools.product([64, 128], ["float16", "float32"], [0, 1])) +class RopeMode(enum.IntEnum): + """The RoPE mode of the Paged KV cache. + If it is none, the KV cache will not apply RoPE to q and k. + If it is normal, RoPE will be applied to k before adding k to cache. + Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly. 
+ """ + + NONE = 0 + NORMAL = 1 + INLINE = 2 + + +@pytest.fixture( + params=itertools.product( + [64, 128], + ["float16", "float32"], + [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE], + ) +) def kv_cache_and_rope_mode(request): global head_dim, dtype head_dim, dtype, rope_mode = request.param @@ -181,7 +200,7 @@ def f_apply_rotary(x, offset, scale, theta): def apply_attention( kv_cache, - rope_mode: int, + rope_mode: RopeMode, batch: List[Tuple[Union[int, Tuple[int, int]], int]], cached_k: Dict[int, np.ndarray], cached_v: Dict[int, np.ndarray], @@ -228,7 +247,7 @@ def apply_attention( [ ( new_k[l] - if rope_mode == 1 + if rope_mode != RopeMode.NORMAL else f_apply_rotary( new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta ) @@ -267,15 +286,19 @@ def apply_attention( assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= append_length rope_offset = cached_k[seq_id].shape[1] - append_length - q_seq = f_apply_rotary( - q_array[i][layer_id], - rope_offset, - rope_scale, - rope_theta, + q_seq = ( + q_array[i][layer_id] + if rope_mode == RopeMode.NONE + else f_apply_rotary( + q_array[i][layer_id], + rope_offset, + rope_scale, + rope_theta, + ) ).transpose(1, 0, 2) k_seq = ( cached_k[seq_id][layer_id] - if rope_mode == 0 + if rope_mode != RopeMode.INLINE else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta) ).transpose(1, 2, 0) v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2) @@ -1639,7 +1662,7 @@ def _merge_state_inplace(num_heads, head_dim, v_dtype): if __name__ == "__main__": for head_dim in [64, 128]: for dtype in ["float16", "float32"]: - for rope_mode in [0, 1]: + for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]: set_global_func(head_dim, dtype) cache = create_kv_cache(head_dim, dtype, rope_mode) for fuse_qkv in [False, True]: