This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6333d86105 [KVCache] Support mode "None" for Rotary Embedding (#16580)
6333d86105 is described below

commit 6333d86105e28435070976058f7830a15a751fa3
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Fri Feb 16 07:54:18 2024 -0500

    [KVCache] Support mode "None" for Rotary Embedding (#16580)
    
    This PR supports a "None" Rotary Embedding mode in
    PagedKVCache. When the mode is None, the rotary embedding
    will not be applied to q and k when computing attention.
---
 src/runtime/relax_vm/paged_kv_cache.cc             |  6 ++-
 ..._builtin_paged_attention_kv_cache_flashinfer.py | 45 +++++++++++++++-------
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 43 ++++++++++++++++-----
 3 files changed, 69 insertions(+), 25 deletions(-)

diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 7417d90e02..d5ddef7527 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -153,12 +153,14 @@ struct Sequence {
 /*!
  * \brief The rotary embedding mode adopted by the paged KV cache
  * when computing attention.
+ * "None" means RoPE is never applied to q and k.
  * "Normal" means RoPE is computed in a standalone kernel.
  * "Inline" means RoPE is computed on-the-fly in attention kernels.
  */
 enum class RoPEMode : int {
-  kNormal = 0,
-  kInline = 1,
+  kNone = 0,
+  kNormal = 1,
+  kInline = 2,
 };
 
 /*!
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
index 8a40f43ab7..fccef312c1 100644
--- 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
+++ 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_flashinfer.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import enum
 from typing import Dict, List, Tuple, Union
 
 import numpy as np
@@ -322,7 +323,19 @@ def create_kv_cache(rope_mode):
     return cache
 
 
-@pytest.fixture(params=[0, 1])
+class RopeMode(enum.IntEnum):
+    """The RoPE mode of the Paged KV cache.
+    If it is none, the KV cache will not apply RoPE to q and k.
+    If it is normal, RoPE will be applied to k before adding k to cache.
+    Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
+    """
+
+    NONE = 0
+    NORMAL = 1
+    INLINE = 2
+
+
+@pytest.fixture(params=[RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE])
 def kv_cache_and_rope_mode(request):
     set_global_func()
     return create_kv_cache(request.param), request.param
@@ -361,7 +374,7 @@ def f_apply_rotary(x, offset, scale, theta):
 
 def apply_attention(
     kv_cache,
-    rope_mode: int,
+    rope_mode: RopeMode,
     batch: List[Tuple[Union[int, Tuple[int, int]], int]],
     cached_k: Dict[int, np.ndarray],
     cached_v: Dict[int, np.ndarray],
@@ -406,10 +419,12 @@ def apply_attention(
                 cached_k[seq_id],
                 np.stack(
                     [
-                        new_k[l]
-                        if rope_mode == 1
-                        else f_apply_rotary(
-                            new_k[l], cached_k[seq_id].shape[1], rope_scale, 
rope_theta
+                        (
+                            new_k[l]
+                            if rope_mode != RopeMode.NORMAL
+                            else f_apply_rotary(
+                                new_k[l], cached_k[seq_id].shape[1], 
rope_scale, rope_theta
+                            )
                         )
                         for l in range(num_layers)
                     ],
@@ -445,15 +460,19 @@ def apply_attention(
             assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= 
append_length
 
             rope_offset = cached_k[seq_id].shape[1] - append_length
-            q_seq = f_apply_rotary(
-                q_array[i][layer_id],
-                rope_offset,
-                rope_scale,
-                rope_theta,
+            q_seq = (
+                q_array[i][layer_id]
+                if rope_mode == RopeMode.NONE
+                else f_apply_rotary(
+                    q_array[i][layer_id],
+                    rope_offset,
+                    rope_scale,
+                    rope_theta,
+                )
             ).transpose(1, 0, 2)
             k_seq = (
                 cached_k[seq_id][layer_id]
-                if rope_mode == 0
+                if rope_mode != RopeMode.INLINE
                 else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, 
rope_theta)
             ).transpose(1, 2, 0)
             v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
@@ -586,7 +605,7 @@ def 
test_paged_attention_kv_cache_popn(kv_cache_and_rope_mode, fuse_qkv):
 
 if __name__ == "__main__":
     set_global_func()
-    for rope_mode in [0, 1]:
+    for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
         cache = create_kv_cache(rope_mode)
         for fuse_qkv in [False, True]:
             test_paged_attention_kv_cache_prefill_and_decode((cache, 
rope_mode), fuse_qkv)
diff --git 
a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py 
b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index dc4d4082f1..8bd9da3bbb 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import enum
 import itertools
 import math
 from typing import Dict, List, Tuple, Union
@@ -140,7 +141,25 @@ def create_kv_cache(head_dim, dtype, rope_mode):
     return cache
 
 
-@pytest.fixture(params=itertools.product([64, 128], ["float16", "float32"], 
[0, 1]))
+class RopeMode(enum.IntEnum):
+    """The RoPE mode of the Paged KV cache.
+    If it is none, the KV cache will not apply RoPE to q and k.
+    If it is normal, RoPE will be applied to k before adding k to cache.
+    Otherwise, RoPE will be applied to q/k in attention kernel on-the-fly.
+    """
+
+    NONE = 0
+    NORMAL = 1
+    INLINE = 2
+
+
+@pytest.fixture(
+    params=itertools.product(
+        [64, 128],
+        ["float16", "float32"],
+        [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE],
+    )
+)
 def kv_cache_and_rope_mode(request):
     global head_dim, dtype
     head_dim, dtype, rope_mode = request.param
@@ -181,7 +200,7 @@ def f_apply_rotary(x, offset, scale, theta):
 
 def apply_attention(
     kv_cache,
-    rope_mode: int,
+    rope_mode: RopeMode,
     batch: List[Tuple[Union[int, Tuple[int, int]], int]],
     cached_k: Dict[int, np.ndarray],
     cached_v: Dict[int, np.ndarray],
@@ -228,7 +247,7 @@ def apply_attention(
                     [
                         (
                             new_k[l]
-                            if rope_mode == 1
+                            if rope_mode != RopeMode.NORMAL
                             else f_apply_rotary(
                                 new_k[l], cached_k[seq_id].shape[1], 
rope_scale, rope_theta
                             )
@@ -267,15 +286,19 @@ def apply_attention(
             assert cached_k[seq_id].shape[1] == cached_v[seq_id].shape[1] >= 
append_length
 
             rope_offset = cached_k[seq_id].shape[1] - append_length
-            q_seq = f_apply_rotary(
-                q_array[i][layer_id],
-                rope_offset,
-                rope_scale,
-                rope_theta,
+            q_seq = (
+                q_array[i][layer_id]
+                if rope_mode == RopeMode.NONE
+                else f_apply_rotary(
+                    q_array[i][layer_id],
+                    rope_offset,
+                    rope_scale,
+                    rope_theta,
+                )
             ).transpose(1, 0, 2)
             k_seq = (
                 cached_k[seq_id][layer_id]
-                if rope_mode == 0
+                if rope_mode != RopeMode.INLINE
                 else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, 
rope_theta)
             ).transpose(1, 2, 0)
             v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
@@ -1639,7 +1662,7 @@ def _merge_state_inplace(num_heads, head_dim, v_dtype):
 if __name__ == "__main__":
     for head_dim in [64, 128]:
         for dtype in ["float16", "float32"]:
-            for rope_mode in [0, 1]:
+            for rope_mode in [RopeMode.NONE, RopeMode.NORMAL, RopeMode.INLINE]:
                 set_global_func(head_dim, dtype)
                 cache = create_kv_cache(head_dim, dtype, rope_mode)
                 for fuse_qkv in [False, True]:

Reply via email to