This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8b4df725b7 [Runtime][KVCache] Initial interface setup for MLA (#17616)
8b4df725b7 is described below

commit 8b4df725b797a05e87d80d165dc7bbb6774aa869
Author: Ruihang Lai <[email protected]>
AuthorDate: Thu Jan 30 20:00:22 2025 -0500

    [Runtime][KVCache] Initial interface setup for MLA (#17616)
    
    This PR introduces the initial KV cache interface setup for multi-head
    latent attention in DeepSeek models.
    
    Some interface implementations are marked as TODO and will be
    implemented in the near future.
---
 src/runtime/relax_vm/kv_state.h        |  63 +++++++
 src/runtime/relax_vm/paged_kv_cache.cc | 313 ++++++++++++++++++++++++++++-----
 2 files changed, 330 insertions(+), 46 deletions(-)

diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 7df3215d08..77c17d1c55 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -181,6 +181,69 @@ class AttentionKVCacheObj : public KVStateObj {
   virtual void AttentionWithFusedQKV(int64_t layer_id, NDArray qkv_data, 
Optional<NDArray> mask,
                                      NDArray o_data, double 
attn_score_scaling_factor) = 0;
 
+  /*!
+   * \brief Compute attention with Q/K/V data.
+   * \param layer_id The model layer where the attention compute happens.
+   * \param q_data The input Q data, in layout `(total_length, num_qo_heads, 
head_dim)`
+   * \param k_data The input K data, in layout `(total_length, num_kv_heads, 
head_dim)`
+   * \param v_data The input V data, in layout `(total_length, num_kv_heads, 
head_dim)`
+   * \param mask The input mask data, in layout `(total_sqr_length)`.
+   * \param o_data The output O data, in layout `(total_length, num_qo_heads, 
head_dim)`.
+   * \param attn_score_scaling_factor The additional attention scaling factor.
+   */
+  virtual void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, 
NDArray k_data,
+                                        NDArray v_data, Optional<NDArray> 
mask, NDArray o_data,
+                                        double attn_score_scaling_factor) = 0;
+
+  /*!
+   * \brief Compute multi-head latent attention after applying weight 
absorption.
+   * \param layer_id The model layer where the attention compute happens.
+   * \param q_data The input Q data, in layout `(total_length, num_qo_heads, 
qk_head_dim)`
+   * \param compressed_kv_data The compressed latent KV data, in layout
+   * `(total_length, num_kv_heads, kv_lora_rank)`
+   * \param k_pe_data The positional embedding part of K data, in layout
+   * `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + 
qk_rope_head_dim`
+   * equals qk_head_dim
+   * \param o_data The output O data, in layout `(total_length, num_qo_heads, 
v_head_dim)`.
+   * \param attn_score_scaling_factor The additional attention scaling factor.
+   */
+  virtual void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray 
compressed_kv_data,
+                           NDArray k_pe_data, NDArray o_data, double 
attn_score_scaling_factor) = 0;
+
+  /*!
+   * \brief Compute multi-head latent attention in normal style.
+   * \param layer_id The model layer where the attention compute happens.
+   * \param q_data The input Q data, in layout
+   * `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)`
+   * \param k_data The input K data, in layout
+   * `(total_length, num_qo_heads, qk_nope_head_dim + qk_rope_head_dim)`
+   * \param v_data The input V data, in layout
+   * `(total_length, num_qo_heads, v_head_dim)`
+   * \param compressed_kv_data The compressed latent KV data, in layout
+   * `(total_length, num_kv_heads, kv_lora_rank)`
+   * \param k_pe_data The positional embedding part of K data, in layout
+   * `(total_length, num_kv_heads, qk_rope_head_dim)`, where `kv_lora_rank + 
qk_rope_head_dim`
+   * equals qk_head_dim
+   * \param o_data The output O data, in layout `(total_length, num_qo_heads, 
v_head_dim)`.
+   * \param attn_score_scaling_factor The additional attention scaling factor.
+   */
+  virtual void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, 
NDArray v_data,
+                         NDArray compressed_kv_data, NDArray k_pe_data, 
NDArray o_data,
+                         double attn_score_scaling_factor) = 0;
+
+  /*!
+   * \brief Compute linear attention with Q/K/V data.
+   * \param layer_id The model layer where the attention compute happens.
+   * \param q_data The input Q data, in layout `(total_length, num_qo_heads, 
head_dim)`.
+   * \param k_data The input K data, in layout `(total_length, num_kv_heads, 
head_dim)`.
+   * \param v_data The input V data, in layout `(total_length, num_kv_heads, 
head_dim)`.
+   * \param o_data The output O data, in layout `(total_length, num_qo_heads, 
head_dim)`.
+   * \param attn_score_scaling_factor The additional attention scaling factor.
+   * \sa AttentionKVCache::Attention
+   */
+  virtual void LinearAttention(int64_t layer_id, NDArray q_data, NDArray 
k_data, NDArray v_data,
+                               double attn_score_scaling_factor) = 0;
+
   /************** Positions **************/
 
   /*!
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 81c55bfcb6..8e5dfb4bd8 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -64,6 +64,33 @@ constexpr const int kFloatAttnWorkspaceByte = 768 * 1024 * 
1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding 
window. */
 constexpr const int kPagedKVCacheTempPageId = -1;
 
+/*!
+ * \brief The supported attention kinds in PagedKVCache.
+ * "MHA" means multi-head attention, multi-query attention and grouped query 
attention in general.
+ * "MLA" means multi-head latent attention.
+ * "LinearAttn" means linear attention.
+ */
+enum class AttnKind : int {
+  kMHA = 0,
+  kMLA = 1,
+  kLinearAttn = 2,
+};
+
+ShapeTuple GetKVCacheShape(AttnKind attn_kind, int64_t num_total_pages, int 
num_sequence,
+                           int64_t num_kv_heads, int64_t page_size, int64_t 
qk_head_dim,
+                           int64_t v_head_dim, int64_t qk_rope_head_dim) {
+  if (attn_kind == AttnKind::kMHA) {
+    // Ignore v_head_dim since multi-head attention requires K/V to have the 
same head dim.
+    return {num_total_pages, 2, num_kv_heads, page_size, qk_head_dim};
+  } else if (attn_kind == AttnKind::kMLA) {
+    return {num_total_pages, num_kv_heads, page_size, qk_head_dim + 
qk_rope_head_dim};
+  } else if (attn_kind == AttnKind::kLinearAttn) {
+    return {num_sequence, num_kv_heads, qk_head_dim, v_head_dim};
+  }
+  ICHECK(false);
+  throw;
+}
+
 /*!
  * \brief The block structure in paged KV cache with common prefix support.
  * Each block contains a list of pages for cached KV data.
@@ -940,13 +967,25 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   /*! \brief The number of key/value heads in the model. */
   const int64_t num_kv_heads_;
   /*! \brief The number of features each head has. */
-  const int64_t head_dim_;
+  const int64_t qk_head_dim_;
+  /*!
+   * \brief The number of features each head has for V.
+   * For layers that use multi-head attention, this field is overridden by 
qk_head_dim.
+   */
+  const int64_t v_head_dim_;
+  /*!
+   * \brief The number of features each head has for RoPE in multi-head latent 
attention.
+   * This field is ignored for non-MLA.
+   */
+  const int64_t qk_rope_head_dim_;
   /*! \brief The number of total pages allocated in KV cache. */
   const int64_t num_total_pages_;
   /*! \brief The maximum total sequence length in a prefill. */
   const int64_t prefill_chunk_size_;
   /*! \brief A boolean flag indicating if the KV cache supports sliding 
window. */
   const bool support_sliding_window_;
+  /*! \brief The attention kinds for each layer. */
+  const std::vector<AttnKind> attn_kinds_;
 
   /*! \brief The RoPE application mode of KV cache.*/
   const RoPEMode rope_mode_;
@@ -967,7 +1006,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   * If KV transfer function is specified, pages_ will be allocated by NVSHMEM 
as a whole NDArray.
    * pages_ will contain tensor view of each layer.
    * Otherwise, pages_ has `num_layers` NDArrays, each of them
-   * has layout (num_pages, 2, num_heads, page_size, head_dim).
+   * has layout (num_pages, 2, num_heads, page_size, qk_head_dim).
   * Along the "2" dimension, index 0 stands for K and 1 stands for V.
    */
   std::vector<NDArray> pages_;
@@ -1086,6 +1125,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   std::vector<NDArray> tree_attn_mn_indptr_view_;
 
   PackedFunc f_transpose_append_;
+  PackedFunc f_transpose_append_mla_;
   Optional<PackedFunc> f_transfer_kv_;
   Optional<PackedFunc> f_transfer_kv_page_to_page_ = NullOpt;
   PackedFunc f_compact_copy_;
@@ -1102,8 +1142,13 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   Optional<PackedFunc> f_attention_prefill_end_forward_;
   Optional<PackedFunc> f_attention_decode_begin_forward_;
   Optional<PackedFunc> f_attention_decode_end_forward_;
+  PackedFunc f_mla_prefill_;
+  PackedFunc f_mla_decode_;
+  PackedFunc f_mla_prefill_ragged_normal_;
+  PackedFunc f_mla_prefill_ragged_absorbed_;
   PackedFunc f_merge_inplace_;
   PackedFunc f_split_rotary_;
+  PackedFunc f_separate_rotary_;
   PackedFunc f_copy_single_page_;
   Optional<PackedFunc> f_debug_get_kv_;
 
@@ -1120,37 +1165,45 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   /*! \brief Constructor. Take the cache configuration and initialize the 
NDArrays. */
   explicit PagedAttentionKVCacheObj(
       int64_t page_size, int64_t num_layers, int64_t layer_id_begin_offset,  //
-      int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim, int64_t 
reserved_num_seqs,
+      int64_t num_qo_heads, int64_t num_kv_heads, int64_t qk_head_dim, int64_t 
v_head_dim,
+      int64_t qk_rope_head_dim, std::vector<AttnKind> attn_kinds, int64_t 
reserved_num_seqs,
       int64_t num_total_pages, int64_t prefill_chunk_size, bool 
support_sliding_window,
       RoPEMode rope_mode, double rotary_scale, double rotary_theta,
       Optional<NDArray> rope_ext_factors, bool enable_kv_transfer, DLDataType 
dtype, Device device,
-      PackedFunc f_transpose_append, PackedFunc f_compact_copy, PackedFunc 
f_attention_prefill,
-      PackedFunc f_attention_decode, PackedFunc 
f_attention_prefill_sliding_window,
-      PackedFunc f_attention_decode_sliding_window, PackedFunc 
f_attention_prefill_ragged,
-      PackedFunc f_attention_prefill_with_tree_mask,
+      PackedFunc f_transpose_append, PackedFunc f_transpose_append_mla, 
PackedFunc f_compact_copy,
+      PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
+      PackedFunc f_attention_prefill_sliding_window, PackedFunc 
f_attention_decode_sliding_window,
+      PackedFunc f_attention_prefill_ragged, PackedFunc 
f_attention_prefill_with_tree_mask,
       PackedFunc f_attention_prefill_with_tree_mask_paged_kv,
       Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
       Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
       Optional<PackedFunc> f_attention_prefill_begin_forward,
       Optional<PackedFunc> f_attention_prefill_end_forward,
       Optional<PackedFunc> f_attention_decode_begin_forward,
-      Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc 
f_merge_inplace,
-      PackedFunc f_split_rotary, PackedFunc f_copy_single_page, 
Optional<PackedFunc> f_debug_get_kv)
+      Optional<PackedFunc> f_attention_decode_end_forward, PackedFunc 
f_mla_prefill,
+      PackedFunc f_mla_decode, PackedFunc f_mla_prefill_ragged_normal,
+      PackedFunc f_mla_prefill_ragged_absorbed, PackedFunc f_merge_inplace,
+      PackedFunc f_split_rotary, PackedFunc f_separate_rotary, PackedFunc 
f_copy_single_page,
+      Optional<PackedFunc> f_debug_get_kv)
       : page_size_(page_size),
         num_layers_(num_layers),
         layer_id_begin_offset_(layer_id_begin_offset),
         num_qo_heads_(num_qo_heads),
         num_kv_heads_(num_kv_heads),
-        head_dim_(head_dim),
+        qk_head_dim_(qk_head_dim),
+        v_head_dim_(v_head_dim),
+        qk_rope_head_dim_(qk_rope_head_dim),
         num_total_pages_(num_total_pages),
         prefill_chunk_size_(prefill_chunk_size),
         support_sliding_window_(support_sliding_window),
+        attn_kinds_(std::move(attn_kinds)),
         rope_mode_(support_sliding_window && rope_mode != RoPEMode::kNone ? 
RoPEMode::kInline
                                                                           : 
rope_mode),
         rotary_scale_(rotary_scale),
         rotary_theta_(rotary_theta),
         rope_ext_factors_(std::move(rope_ext_factors)),
         f_transpose_append_(std::move(f_transpose_append)),
+        f_transpose_append_mla_(std::move(f_transpose_append_mla)),
         f_compact_copy_(std::move(f_compact_copy)),
         f_attention_prefill_(std::move(f_attention_prefill)),
         f_attention_decode_(std::move(f_attention_decode)),
@@ -1167,24 +1220,33 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         
f_attention_prefill_end_forward_(std::move(f_attention_prefill_end_forward)),
         
f_attention_decode_begin_forward_(std::move(f_attention_decode_begin_forward)),
         
f_attention_decode_end_forward_(std::move(f_attention_decode_end_forward)),
+        f_mla_prefill_(std::move(f_mla_prefill)),
+        f_mla_decode_(std::move(f_mla_decode)),
+        f_mla_prefill_ragged_normal_(std::move(f_mla_prefill_ragged_normal)),
+        
f_mla_prefill_ragged_absorbed_(std::move(f_mla_prefill_ragged_absorbed)),
         f_merge_inplace_(std::move(f_merge_inplace)),
         f_split_rotary_(std::move(f_split_rotary)),
+        f_separate_rotary_(std::move(f_separate_rotary)),
         f_copy_single_page_(std::move(f_copy_single_page)),
         f_debug_get_kv_(std::move(f_debug_get_kv)),
         device_(device) {
     pages_.reserve(num_layers);
     if (enable_kv_transfer) {
+      // For now, KV transfer only supports MHA.
+      for (AttnKind attn_kind : attn_kinds_) {
+        CHECK(attn_kind == AttnKind::kMHA);
+      }
       CHECK(Registry::Get("runtime.disco.nvshmem.init_nvshmem") != nullptr)
           << "NVSHMEM is not enabled. Please make sure NVSHMEM is enabled when 
compiling TVM.";
       const PackedFunc* f_nvshmem_empty = 
runtime::Registry::Get("runtime.disco.nvshmem.empty");
       ICHECK_NOTNULL(f_nvshmem_empty);
       nvshmem_pages_ = (*f_nvshmem_empty)(
-          ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, 
head_dim}), dtype,
+          ShapeTuple({num_layers, num_total_pages, 2, num_kv_heads, page_size, 
qk_head_dim}), dtype,
           device);
       for (int i = 0; i < num_layers; ++i) {
         pages_.push_back(nvshmem_pages_.CreateView(
-            {num_total_pages_, 2, num_kv_heads_, page_size_, head_dim_}, 
nvshmem_pages_->dtype,
-            i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * head_dim_ *
+            {num_total_pages_, 2, num_kv_heads_, page_size_, qk_head_dim_}, 
nvshmem_pages_->dtype,
+            i * num_total_pages_ * 2 * num_kv_heads_ * page_size_ * 
qk_head_dim_ *
                 nvshmem_pages_.DataType().bytes()));
       }
 
@@ -1197,8 +1259,10 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       f_transfer_kv_page_to_page_ = *f_transfer_kv_page_to_page_ptr;
     } else {
       for (int i = 0; i < num_layers; ++i) {
-        pages_.push_back(
-            NDArray::Empty({num_total_pages, 2, num_kv_heads, page_size, 
head_dim}, dtype, device));
+        ShapeTuple kv_cache_shape = GetKVCacheShape(
+            attn_kinds_[layer_id_begin_offset_ + i], num_total_pages, 
reserved_num_seqs,
+            num_kv_heads, page_size, qk_head_dim, v_head_dim, 
qk_rope_head_dim);
+        pages_.push_back(NDArray::Empty(kv_cache_shape, dtype, device));
       }
     }
 
@@ -1274,13 +1338,13 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
 
     temp_attn_q_device_ =
-        NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, 
device);
+        NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, 
dtype, device);
     temp_attn_k_device_ =
-        NDArray::Empty({prefill_chunk_size_, num_kv_heads, head_dim}, dtype, 
device);
+        NDArray::Empty({prefill_chunk_size_, num_kv_heads, qk_head_dim}, 
dtype, device);
     temp_attn_v_device_ =
-        NDArray::Empty({prefill_chunk_size_, num_kv_heads, head_dim}, dtype, 
device);
+        NDArray::Empty({prefill_chunk_size_, num_kv_heads, v_head_dim}, dtype, 
device);
     temp_attn_output_device_ =
-        NDArray::Empty({prefill_chunk_size_, num_qo_heads, head_dim}, dtype, 
device);
+        NDArray::Empty({prefill_chunk_size_, num_qo_heads, qk_head_dim}, 
dtype, device);
     temp_attn_scores_device_ =
         NDArray::Empty({prefill_chunk_size_, num_qo_heads}, 
DataType::Float(32), device);
     merged_attn_scores_device_ =
@@ -2019,8 +2083,9 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     CHECK(qkv_data.DataType() == pages.DataType());
     CHECK(o_data.DataType() == pages.DataType());
 
-    // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, head_dim)
-    // o_data: (num_total_length, num_qo_heads, head_dim)
+    CHECK(attn_kinds_[layer_id] == AttnKind::kMHA);
+    // qkv_data: (num_total_length, num_qo_heads + 2 * num_kv_heads, 
qk_head_dim)
+    // o_data: (num_total_length, num_qo_heads, qk_head_dim)
 
     CHECK_EQ(qkv_data->ndim, 3);
     CHECK_EQ(o_data->ndim, 3);
@@ -2033,7 +2098,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       }
     }
 
-    CHECK_EQ(qkv_data->shape[2], head_dim_);
+    CHECK_EQ(qkv_data->shape[2], qk_head_dim_);
     int64_t total_seq_length = 0;
     for (int64_t seq_id = 0; seq_id < cur_batch_size_; ++seq_id) {
       total_seq_length += cur_append_lengths_[seq_id];
@@ -2044,11 +2109,11 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     // The auxiliary data structure on device must have been synchronized.
     ICHECK(!dirty_aux_data_device_);
 
-    NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length, 
num_qo_heads_, head_dim_},
+    NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length, 
num_qo_heads_, qk_head_dim_},
                                                     qkv_data->dtype);
-    NDArray k_data = temp_attn_k_device_.CreateView({total_seq_length, 
num_kv_heads_, head_dim_},
+    NDArray k_data = temp_attn_k_device_.CreateView({total_seq_length, 
num_kv_heads_, qk_head_dim_},
                                                     qkv_data->dtype);
-    NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, 
num_kv_heads_, head_dim_},
+    NDArray v_data = temp_attn_v_device_.CreateView({total_seq_length, 
num_kv_heads_, qk_head_dim_},
                                                     qkv_data->dtype);
 
     NDArray qkv_data_view = qkv_data;
@@ -2057,7 +2122,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       qkv_data_view = qkv_data.CreateView(
           {total_seq_length, qkv_data->shape[1], qkv_data->shape[2]}, 
qkv_data->dtype);
       o_data_view =
-          o_data.CreateView({total_seq_length, num_qo_heads_, head_dim_}, 
qkv_data->dtype);
+          o_data.CreateView({total_seq_length, num_qo_heads_, qk_head_dim_}, 
qkv_data->dtype);
     }
     // Part 2. Split fused qkv and apply rotary embedding to q/k data.
     if (transfer_kv_) {
@@ -2105,6 +2170,28 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
   }
 
+  void AttentionWithSeparateQKV(int64_t layer_id, NDArray q_data, NDArray 
k_data, NDArray v_data,
+                                Optional<NDArray> mask, NDArray o_data,
+                                double attn_score_scaling_factor) final {
+    // Todo(ruihang): implement it
+  }
+
+  void MLAAbsorbed(int64_t layer_id, NDArray q_data, NDArray 
compressed_kv_data, NDArray k_pe_data,
+                   NDArray o_data, double attn_score_scaling_factor) {
+    // Todo(ruihang): implement it
+  }
+
+  void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray 
v_data,
+                 NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
+                 double attn_score_scaling_factor) {
+    // Todo(ruihang): implement it
+  }
+
+  void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, 
NDArray v_data,
+                       double attn_score_scaling_factor) {
+    // Todo(ruihang): implement it
+  }
+
   void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple& 
leaf_indices) final {
     CHECK_EQ(seq_ids.size(), leaf_indices.size())
         << "The given seq_ids and leaf_indices have different size.";
@@ -2216,9 +2303,10 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     CHECK_LE(end_pos, seq.seq_length) << "DebugGetKV does not accept 
out-of-range end_pos";
     CHECK_LT(start_pos, end_pos) << "DebugGetKV does not accept \"start_pos >= 
end_pos\"";
 
-    // k/v_data: (num_layers, seq_length, num_kv_heads, head_dim)
+    // k/v_data: (num_layers, seq_length, num_kv_heads, qk_head_dim)
     static constexpr const char* error_msg =
-        "DebugGetKV expects the k_data in layout (num_layers, seq_length, 
num_kv_heads, head_dim).";
+        "DebugGetKV expects the k_data in layout (num_layers, seq_length, 
num_kv_heads, "
+        "qk_head_dim).";
     std::vector<NDArray*> vec_kv_data = {&k_data, &v_data};
     for (const NDArray* data_ptr : vec_kv_data) {
       CHECK_EQ((*data_ptr)->ndim, 4) << error_msg;
@@ -2228,7 +2316,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
           << error_msg << " The sequence length mismatches.";
       CHECK_EQ((*data_ptr)->shape[2], num_kv_heads_)
           << error_msg << " The number of heads mismatches.";
-      CHECK_EQ((*data_ptr)->shape[3], head_dim_)
+      CHECK_EQ((*data_ptr)->shape[3], qk_head_dim_)
           << error_msg << " The number of head features mismatches.";
     }
 
@@ -2250,6 +2338,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         append_position_map.data() + start_pos,
         (end_pos - start_pos) * ((dtype_aux_.bits * dtype_aux_.lanes + 7) / 
8));
     for (int64_t layer_id = 0; layer_id < num_layers_; ++layer_id) {
+      CHECK(attn_kinds_[layer_id] == AttnKind::kMHA) << "Only MHA is supported 
for DebugGetKV";
       f_debug_get_kv_.value()(pages_[layer_id], position_map_device, k_data, 
v_data, layer_id);
     }
   }
@@ -2649,7 +2738,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
             temp_float_attn_workspace_, temp_int_attn_workspace_[0],
             cur_append_lengths_indptr_host_.as_ndarray(),
             cur_append_lengths_indptr_host_.as_ndarray(), cur_batch_size_, 
num_qo_heads_,
-            num_kv_heads_, head_dim_, copy_stream_);
+            num_kv_heads_, qk_head_dim_, copy_stream_);
       }
     }
     for (int d = 0; d < num_depths_; ++d) {
@@ -2661,15 +2750,15 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
         f_attention_decode_begin_forward_.value()(
             d, temp_float_attn_workspace_, temp_int_attn_workspace_[d + 1],
             page_indptr_on_depths_host_[d].as_ndarray(),
-            last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, 
num_kv_heads_, head_dim_,
-            page_size_,
+            last_page_len_on_depths_host_[d].as_ndarray(), num_qo_heads_, 
num_kv_heads_,
+            qk_head_dim_, page_size_,
             /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
       } else {
         f_attention_prefill_begin_forward_.value()(
             /*depth=*/d, temp_float_attn_workspace_, 
temp_int_attn_workspace_[d + 1],
             qo_indptr_on_depths_host_[d].as_ndarray(), 
page_indptr_on_depths_host_[d].as_ndarray(),
             static_cast<int>(page_indptr_on_depths_host_[d].size()) - 1, 
num_qo_heads_,
-            num_kv_heads_, head_dim_, page_size_, copy_stream_);
+            num_kv_heads_, qk_head_dim_, page_size_, copy_stream_);
       }
     }
   }
@@ -2893,7 +2982,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
     // 16. Create view for temporary arrays for attention computation.
     temp_attn_output_view_ = temp_attn_output_device_.CreateView(
-        {total_append_length, num_qo_heads_, head_dim_}, 
temp_attn_output_device_->dtype);
+        {total_append_length, num_qo_heads_, qk_head_dim_}, 
temp_attn_output_device_->dtype);
     temp_attn_scores_view_ = temp_attn_scores_device_.CreateView(
         {total_append_length, num_qo_heads_}, temp_attn_scores_device_->dtype);
     merged_attn_scores_view_ = merged_attn_scores_device_.CreateView(
@@ -2964,6 +3053,9 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
         enable_kv_transfer = args[29];
       }
 
+      std::vector<AttnKind> attn_kinds(/*size=*/layer_indptr_tuple[num_groups],
+                                       /*value=*/AttnKind::kMHA);
+
       CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
@@ -2975,13 +3067,18 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
         // When sliding window is enabled, each sequence may use two more 
pages at most.
         num_total_pages += reserved_num_seqs * 2;
       }
+      // NOTE: We will remove this legacy construction after finishing the 
transition phase.
+      // Some `PackedFunc()` here are placeholders that will be filled.
       ObjectPtr<PagedAttentionKVCacheObj> n = 
make_object<PagedAttentionKVCacheObj>(
           page_size, num_layers, layer_id_begin_offset, num_qo_heads, 
num_kv_heads, head_dim,
-          reserved_num_seqs, num_total_pages, prefill_chunk_size, 
support_sliding_window,
-          RoPEMode(rope_mode), rotary_scale, rotary_theta, 
std::move(rope_ext_factors),  //
-          enable_kv_transfer, init->dtype, init->device,                       
          //
-          std::move(f_transpose_append), std::move(f_compact_copy), 
std::move(f_attention_prefill),
-          std::move(f_attention_decode), 
std::move(f_attention_prefill_sliding_window),
+          head_dim, /*qk_rope_head_dim=*/0, attn_kinds, reserved_num_seqs, 
num_total_pages,
+          prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), 
rotary_scale,
+          rotary_theta,
+          std::move(rope_ext_factors),                    //
+          enable_kv_transfer, init->dtype, init->device,  //
+          std::move(f_transpose_append), PackedFunc(), 
std::move(f_compact_copy),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
+          std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),
           std::move(f_attention_prefill_with_tree_mask_paged_kv),
@@ -2989,7 +3086,8 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
           std::move(f_attention_prefill_ragged_end_forward),
           std::move(f_attention_prefill_begin_forward), 
std::move(f_attention_prefill_end_forward),
           std::move(f_attention_decode_begin_forward), 
std::move(f_attention_decode_end_forward),
-          std::move(f_merge_inplace), std::move(f_split_rotary), 
std::move(f_copy_single_page),
+          PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(), 
std::move(f_merge_inplace),
+          std::move(f_split_rotary), PackedFunc(), 
std::move(f_copy_single_page),
           std::move(f_debug_get_kv));
       *rv = AttentionKVCache(std::move(n));
     });
@@ -3040,6 +3138,9 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
         enable_kv_transfer = args[23];
       }
 
+      std::vector<AttnKind> attn_kinds(/*size=*/layer_indptr_tuple[num_groups],
+                                       /*value=*/AttnKind::kMHA);
+
       CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
@@ -3051,18 +3152,138 @@ 
TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
         // When sliding window is enabled, each sequence may use two more 
pages at most.
         num_total_pages += reserved_num_seqs * 2;
       }
+      // NOTE: We will remove this legacy construction after finishing the 
transition phase.
+      // Some `PackedFunc()` here are placeholders that will be filled.
       ObjectPtr<PagedAttentionKVCacheObj> n = 
make_object<PagedAttentionKVCacheObj>(
           page_size, num_layers, layer_id_begin_offset, num_qo_heads, 
num_kv_heads, head_dim,
-          reserved_num_seqs, num_total_pages, prefill_chunk_size, 
support_sliding_window,
-          RoPEMode(rope_mode), rotary_scale, rotary_theta, 
std::move(rope_ext_factors),  //
-          enable_kv_transfer, init->dtype, init->device,                       
          //
-          std::move(f_transpose_append), std::move(f_compact_copy), 
std::move(f_attention_prefill),
-          std::move(f_attention_decode), 
std::move(f_attention_prefill_sliding_window),
+          head_dim, /*qk_rope_head_dim=*/0, attn_kinds, reserved_num_seqs, 
num_total_pages,
+          prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), 
rotary_scale,
+          rotary_theta,
+          std::move(rope_ext_factors),                    //
+          enable_kv_transfer, init->dtype, init->device,  //
+          std::move(f_transpose_append), PackedFunc(), 
std::move(f_compact_copy),
+          std::move(f_attention_prefill), std::move(f_attention_decode),
+          std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),
           std::move(f_attention_prefill_with_tree_mask),           //
           std::move(f_attention_prefill_with_tree_mask_paged_kv),  //
           NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,    //
-          std::move(f_merge_inplace), std::move(f_split_rotary), 
std::move(f_copy_single_page),
+          PackedFunc(), PackedFunc(), PackedFunc(), PackedFunc(), 
std::move(f_merge_inplace),
+          std::move(f_split_rotary), PackedFunc(), 
std::move(f_copy_single_page),
+          std::move(f_debug_get_kv));
+      *rv = AttentionKVCache(std::move(n));
+    });
+
+TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced_mla")
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.size() == 39) << "Invalid number of KV cache constructor 
args.";
+      ShapeTuple cache_config = args[0];
+      ShapeTuple layer_indptr_tuple = args[1];
+      int num_groups = 1;
+      int group_id = 0;
+      if (DiscoWorker* disco_worker = ThreadLocalDiscoWorker::Get()->worker) {
+        // In the Disco worker thread
+        num_groups = disco_worker->num_groups;
+        group_id = disco_worker->worker_id / (disco_worker->num_workers / 
num_groups);
+      }
+      CHECK_EQ(layer_indptr_tuple.size(), num_groups + 1);
+      int64_t num_layers = layer_indptr_tuple[group_id + 1] - 
layer_indptr_tuple[group_id];
+      int64_t layer_id_begin_offset = layer_indptr_tuple[group_id];
+      int64_t num_qo_heads = args[2];
+      int64_t num_kv_heads = args[3];
+      int64_t qk_head_dim = args[4];
+      int64_t v_head_dim = args[5];
+      int64_t qk_rope_head_dim = args[6];
+      IntTuple attn_kinds = args[7];
+      int rope_mode = args[8];
+      double rotary_scale = args[9];
+      double rotary_theta = args[10];
+      NDArray init = args[11];
+      PackedFunc f_transpose_append = args[12];
+      PackedFunc f_transpose_append_mla = args[13];
+      PackedFunc f_attention_prefill = args[14];
+      PackedFunc f_attention_decode = args[15];
+      PackedFunc f_attention_prefill_sliding_window = args[16];
+      PackedFunc f_attention_decode_sliding_window = args[17];
+      PackedFunc f_attention_prefill_ragged = args[18];
+      Optional<PackedFunc> f_attention_prefill_ragged_begin_forward = NullOpt;
+      Optional<PackedFunc> f_attention_prefill_ragged_end_forward = NullOpt;
+      Optional<PackedFunc> f_attention_prefill_begin_forward = NullOpt;
+      Optional<PackedFunc> f_attention_prefill_end_forward = NullOpt;
+      Optional<PackedFunc> f_attention_decode_begin_forward = NullOpt;
+      Optional<PackedFunc> f_attention_decode_end_forward = NullOpt;
+      PackedFunc f_mla_prefill = args[25];
+      PackedFunc f_mla_decode = args[26];
+      PackedFunc f_mla_prefill_ragged_normal = args[27];
+      PackedFunc f_mla_prefill_ragged_absorbed = args[28];
+      PackedFunc f_merge_inplace = args[29];
+      PackedFunc f_split_rotary = args[30];
+      PackedFunc f_separate_rotary = args[31];
+      PackedFunc f_copy_single_page = args[32];
+      Optional<PackedFunc> f_debug_get_kv = args[33];
+      PackedFunc f_compact_copy = args[34];
+      PackedFunc f_attention_prefill_with_tree_mask = args[35];
+      PackedFunc f_attention_prefill_with_tree_mask_paged_kv = args[36];
+      Optional<NDArray> rope_ext_factors = NullOpt;
+      bool enable_kv_transfer = false;
+
+      if (args[37].IsObjectRef<NDArray>()) {
+        rope_ext_factors = args[37].AsObjectRef<NDArray>();
+      }
+      enable_kv_transfer = args[38];
+
+      auto f_convert_optional_packed_func = [&args](int arg_idx) -> 
Optional<PackedFunc> {
+        if (args[arg_idx].IsObjectRef<PackedFunc>()) {
+          return args[arg_idx].AsObjectRef<PackedFunc>();
+        }
+        return NullOpt;
+      };
+      f_attention_prefill_ragged_begin_forward = 
f_convert_optional_packed_func(19);
+      f_attention_prefill_ragged_end_forward = 
f_convert_optional_packed_func(20);
+      f_attention_prefill_begin_forward = f_convert_optional_packed_func(21);
+      f_attention_prefill_end_forward = f_convert_optional_packed_func(22);
+      f_attention_decode_begin_forward = f_convert_optional_packed_func(23);
+      f_attention_decode_end_forward = f_convert_optional_packed_func(24);
+
+      std::vector<AttnKind> attn_kinds_vec;
+      attn_kinds_vec.reserve(attn_kinds.size());
+      for (int64_t attn_kind : attn_kinds) {
+        attn_kinds_vec.push_back(static_cast<AttnKind>(attn_kind));
+      }
+
+      CHECK_EQ(cache_config.size(), 5);
+      int64_t reserved_num_seqs = cache_config[0];
+      int64_t total_token_capacity = cache_config[1];
+      int64_t prefill_chunk_size = cache_config[2];
+      int64_t page_size = cache_config[3];
+      bool support_sliding_window = cache_config[4];
+      int64_t num_total_pages = (total_token_capacity + page_size - 1) / 
page_size + 1;
+      if (support_sliding_window) {
+        // When sliding window is enabled, each sequence may use two more 
pages at most.
+        num_total_pages += reserved_num_seqs * 2;
+      }
+      // NOTE: We will remove this legacy construction after finishing the 
transition phase.
+      // Some `PackedFunc()` here are placeholders that will be filled.
+      ObjectPtr<PagedAttentionKVCacheObj> n = 
make_object<PagedAttentionKVCacheObj>(
+          page_size, num_layers, layer_id_begin_offset, num_qo_heads, 
num_kv_heads, qk_head_dim,
+          v_head_dim, qk_rope_head_dim, attn_kinds_vec, reserved_num_seqs, 
num_total_pages,
+          prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode), 
rotary_scale,
+          rotary_theta,
+          std::move(rope_ext_factors),                    //
+          enable_kv_transfer, init->dtype, init->device,  //
+          std::move(f_transpose_append), std::move(f_transpose_append_mla),
+          std::move(f_compact_copy), std::move(f_attention_prefill), 
std::move(f_attention_decode),
+          std::move(f_attention_prefill_sliding_window),
+          std::move(f_attention_decode_sliding_window), 
std::move(f_attention_prefill_ragged),
+          std::move(f_attention_prefill_with_tree_mask),           //
+          std::move(f_attention_prefill_with_tree_mask_paged_kv),  //
+          std::move(f_attention_prefill_ragged_begin_forward),
+          std::move(f_attention_prefill_ragged_end_forward),
+          std::move(f_attention_prefill_begin_forward), 
std::move(f_attention_prefill_end_forward),
+          std::move(f_attention_decode_begin_forward), 
std::move(f_attention_decode_end_forward),
+          std::move(f_mla_prefill), std::move(f_mla_decode), 
std::move(f_mla_prefill_ragged_normal),
+          std::move(f_mla_prefill_ragged_absorbed), std::move(f_merge_inplace),
+          std::move(f_split_rotary), std::move(f_separate_rotary), 
std::move(f_copy_single_page),
           std::move(f_debug_get_kv));
       *rv = AttentionKVCache(std::move(n));
     });


Reply via email to