This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 31f4721596 [Runtime] Support PagedKVCache with tree attention (#17049)
31f4721596 is described below

commit 31f47215965b3a4d58a0ee1f450965a43ce2fcd0
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Sat Jun 1 07:01:56 2024 -0400

    [Runtime] Support PagedKVCache with tree attention (#17049)
    
    * [Runtime] Support PagedKVCache with tree attention
    
    This PR introduces tree attention to PagedKVCache. With this
    feature, the KV cache is now ready for tree attention cases such as
    speculative decoding trees.

    This PR adds tree attention tests to verify the correctness.

    The changes in this PR to the KVState interface are backward compatible.
    
    * Update kv_state.cc
    
    * Update kv_state.cc
    
    ---------
    
    Co-authored-by: Tianqi Chen <tqc...@users.noreply.github.com>
---
 src/runtime/relax_vm/kv_state.cc                   |  15 +-
 src/runtime/relax_vm/kv_state.h                    |  15 +-
 src/runtime/relax_vm/paged_kv_cache.cc             | 657 +++++++++++++++++----
 src/runtime/relax_vm/rnn_state.cc                  |  16 +-
 ...runtime_builtin_paged_attention_kv_cache_tir.py | 561 +++++++++++++++++-
 5 files changed, 1149 insertions(+), 115 deletions(-)
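
For reference, the tree-attention flow introduced by this commit is driven
through the registered global functions touched below. The following is a
minimal sketch, assuming a PagedKVCache object `kv_cache` has already been
created with the optional `f_compact_copy` and
`f_attention_prefill_with_tree_mask` functions supplied; the sequence id,
append lengths, and tree shape are made-up example values.

    import tvm
    from tvm.runtime import ShapeTuple

    # Assumed to already exist: `kv_cache`, a PagedKVCache created via
    # "vm.builtin.paged_attention_kv_cache_create" with the tree-attention kernels.
    fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
    fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward")
    fcommit = tvm.get_global_func(
        "vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes"
    )

    seq_ids = ShapeTuple([0])         # one sequence in the batch (hypothetical id)
    append_lengths = ShapeTuple([4])  # four draft tokens appended for this sequence
    # Token tree: node 0 is the root (parent -1), nodes 1 and 2 branch from 0,
    # and node 3 extends node 1.
    token_tree_parent_ptr = ShapeTuple([-1, 0, 0, 1])

    fbegin_forward(kv_cache, seq_ids, append_lengths, token_tree_parent_ptr)
    # ... run the model forward; attention uses the tree mask for this batch ...
    fend_forward(kv_cache)

    # Suppose verification accepted the path 0 -> 1 -> 3: commit leaf node 3.
    # The cache compacts the accepted KV data and discards the rejected node 2.
    fcommit(kv_cache, ShapeTuple([3]))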

diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index b1572bf409..b730a4eb07 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -40,13 +40,26 @@ TVM_REGISTER_GLOBAL("vm.builtin.kv_state_fork_sequence")
     .set_body_method<KVState>(&KVStateObj::ForkSequence);
 
TVM_REGISTER_GLOBAL("vm.builtin.kv_state_popn").set_body_method<KVState>(&KVStateObj::PopN);
 TVM_REGISTER_GLOBAL("vm.builtin.kv_state_begin_forward")
-    .set_body_method<KVState>(&KVStateObj::BeginForward);
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.size() == 3 || args.size() == 4)
+          << "KVState BeginForward only accepts 3 or 4 arguments";
+      KVState kv_state = args[0];
+      IntTuple seq_ids = args[1];
+      IntTuple append_lengths = args[2];
+      Optional<IntTuple> token_tree_parent_ptr{nullptr};
+      if (args.size() == 4) {
+        token_tree_parent_ptr = args[3].operator Optional<IntTuple>();
+      }
+      kv_state->BeginForward(seq_ids, append_lengths, token_tree_parent_ptr);
+    });
 TVM_REGISTER_GLOBAL("vm.builtin.kv_state_end_forward")
     .set_body_method<KVState>(&KVStateObj::EndForward);
 
 // Attention KV Cache methods
 
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_enable_sliding_window_for_seq")
     
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::EnableSlidingWindowForSeq);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes")
+    
.set_body_method<AttentionKVCache>(&AttentionKVCacheObj::CommitAcceptedTokenTreeNodes);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_empty")
     .set_body_method<AttentionKVCache>(&AttentionKVCacheObj::Empty);
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_get_num_available_pages")
diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h
index 12a18ba895..8de560f122 100644
--- a/src/runtime/relax_vm/kv_state.h
+++ b/src/runtime/relax_vm/kv_state.h
@@ -89,8 +89,12 @@ class KVStateObj : public Object {
    * in the model forward function.
    * \param seq_ids The ids of the sequences to run in the incoming model forward.
    * \param append_lengths The sequence lengths to run forward for each sequence.
+   * \param token_tree_parent_ptr The parent idx array of the token trees. Its length
+   * is the sum of "append_lengths". Nullptr means the token tree of each sequence
+   * is a chain.
    */
-  virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) = 0;
+  virtual void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
+                            const Optional<IntTuple>& token_tree_parent_ptr = NullOpt) = 0;
 
   /*!
    * \brief Mark the start of the forward function.
@@ -142,6 +146,15 @@ class AttentionKVCacheObj : public KVStateObj {
   virtual void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
                                          int32_t attn_sink_size) = 0;
 
+  /*!
+   * \brief Commit the accepted token tree nodes to the KV cache.
+   * The commit updates the KV cache by compacting the KV data and discarding
+   * the KV data of the rejected tokens.
+   * This is a mandatory step when BeginForward is given a token tree.
+   * \param leaf_indices The leaf token tree node index of each sequence.
+   */
+  virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0;
+
   /************** Attention **************/
 
   /*!
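
The parent-pointer layout documented above is easiest to see on a small
example. The sketch below uses hypothetical values (not taken from this
patch) to show how a batch of two sequences flattens into a single parent
array and when a sequence counts as a chain.

    # Hypothetical batch with append_lengths = [4, 3].
    # Sequence 0 is a tree:  0 -> {1, 2}, 1 -> {3}    parents: [-1, 0, 0, 1]
    # Sequence 1 is a chain: 0 -> 1 -> 2              parents: [-1, 0, 1]
    append_lengths = [4, 3]
    token_tree_parent_ptr = [-1, 0, 0, 1] + [-1, 0, 1]  # length == sum(append_lengths)

    # A sequence degenerates to a chain exactly when every node's parent is the
    # node right before it; passing a null token_tree_parent_ptr declares this
    # for the whole batch.
    def is_chain(parents):
        return all(p == i - 1 for i, p in enumerate(parents))

    assert not is_chain([-1, 0, 0, 1])
    assert is_chain([-1, 0, 1])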
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 4ab0f3f0c6..a5b970e817 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -26,6 +26,8 @@
 #include <tvm/runtime/ndarray.h>
 #include <tvm/runtime/registry.h>
 
+#include <algorithm>
+#include <numeric>
 #include <unordered_map>
 #include <utility>
 #include <vector>
@@ -52,6 +54,8 @@ namespace relax_vm {
  * prefixes) in paged KV cache.
  */
 constexpr const int kPagedKVCacheMaxBlockDepth = 5;
+/*! \brief The maximum tree size of a single sequence in tree attention. */
+constexpr const int kTreeAttnMaxTreeSize = 256;
 /*! \brief The 8MB workspace size for attention auxiliary data. */
 constexpr const int kAttnWorkspaceByte = 8 * 1024 * 1024;
 /*! \brief The id of the temporary logical page, which is useful for sliding window. */
@@ -250,14 +254,14 @@ class HostMemoryVector {
  * This class manages all the int32 auxiliary data on GPU device, such as
  * page table, position arrays, etc..
  *
- * The core functions of this class is `CopyXXXAsync` and `CommitCopy`.
+ * The core functions of this class is `CopyXXXAsync` and `CommitAttnAuxDataCopy`.
  * `CopyXXXAsync` takes the input data on CPU host, and copy the input data
  * to GPU in an asynchronous way, and returns the NDArray view of the data
  * on GPU device.
  *
  * Being asynchronous here means the `CopyXXXAsync` function may not perform
  * data copy from CPU to GPU at the time of being called. Therefore, the
- * returned NDArray view may have wrong result, until `CommitCopy` is
+ * returned NDArray view may have wrong result, until `CommitAttnAuxDataCopy` is
  * explicitly invoked and the data copy stream is synchronized.
  *
  * We design this manager class in order to reduce the data copy overhead.
@@ -274,8 +278,8 @@ class PagedKVCacheAuxDataManager {
   }
 
   virtual ~PagedKVCacheAuxDataManager() = default;
-  /*! \brief Reset the status of copy manager. */
-  virtual void ResetCopy() = 0;
+  /*! \brief Reset the attention auxiliary data status of copy manager. */
+  virtual void ResetAttnAuxDataCopy() = 0;
   /*! \brief Copy the indptr array of append lengths after coalescing. (see GetChunkedBlockIds) */
   virtual NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) = 0;
   /*! \brief Copy the indptr array of page table. */
@@ -315,8 +319,22 @@ class PagedKVCacheAuxDataManager {
    * appending new K/V data.
    */
   virtual NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) = 0;
-  /*! \brief Commit all the copy operations since the last commit. */
-  virtual void CommitCopy() = 0;
+  /*! \brief Copy the tree attention mask. */
+  virtual NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) = 0;
+  /*! \brief Copy the mn indptr of the tree attention mask. */
+  virtual NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) = 0;
+  /*! \brief Commit all the attention auxiliary data copy operations since the last commit. */
+  virtual void CommitAttnAuxDataCopy() = 0;
+
+  /*! \brief Reset the compact KV auxiliary data status of copy manager. */
+  virtual void ResetCompactKVAuxDataCopy() = 0;
+  /*! \brief Copy the length indptr array of KV data copy for each sequence. */
+  virtual NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) = 0;
+  /*! \brief Copy the src/dst position arrays for each sequence. */
+  virtual NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data,
+                                                      HostMemoryVector* dst_data) = 0;
+  /*! \brief Commit all the compact KV auxiliary data copy operations since the last commit. */
+  virtual void CommitCompactKVAuxDataCopy() = 0;
 
  protected:
   /*! \brief The dtype of the auxiliary data. It is expected to be int32. */
@@ -356,10 +374,18 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
     k_ragged_rope_pos_offset_device_ = NDArray::Empty({reserved_num_seqs}, dtype_aux_, device);
     q_rope_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
     append_position_map_device_ = NDArray::Empty({prefill_chunk_size}, dtype_aux_, device);
+    tree_attn_mask_device_ = NDArray::Empty(
+        {kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs}, dtype_aux_, device);
+    tree_attn_mn_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
+
+    commit_copy_length_indptr_device_ = NDArray::Empty({reserved_num_seqs + 1}, dtype_aux_, device);
+    commit_copy_src_dst_pos_in_page_table_device_ =
+        NDArray::Empty({2, std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size)},
+                       dtype_aux_, device);
   }
 
   // The reset of the plain auxiliary data manager is no-op.
-  void ResetCopy() final {}
+  void ResetAttnAuxDataCopy() final {}
   NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
     NDArray view = qo_indptr_on_depths_device_[depth].CreateView(
         {static_cast<int64_t>(data->size())}, dtype_aux_);
@@ -414,6 +440,18 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
     CopyVecDataToArray(view, data->data());
     return view;
   }
+  NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) final {
+    NDArray view =
+        tree_attn_mask_device_.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) final {
+    NDArray view =
+        tree_attn_mn_indptr_device_.CreateView({static_cast<int64_t>(data->size())}, dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
 
   NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
                                      HostMemoryVector* sliding_window_offset,
@@ -431,7 +469,32 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   }
 
   // The commit of the plain auxiliary data manager is no-op.
-  void CommitCopy() final {}
+  void CommitAttnAuxDataCopy() final {}
+
+  // The reset of the plain auxiliary data manager is no-op.
+  void ResetCompactKVAuxDataCopy() final {}
+
+  NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final {
+    NDArray view = commit_copy_length_indptr_device_.CreateView(
+        {static_cast<int64_t>(data->size())}, dtype_aux_);
+    CopyVecDataToArray(view, data->data());
+    return view;
+  }
+  NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data,
+                                              HostMemoryVector* dst_data) final {
+    int n_elem = src_data->size();
+    ICHECK_GT(n_elem, 0);
+    NDArray view =
+        commit_copy_src_dst_pos_in_page_table_device_.CreateView({2, n_elem}, dtype_aux_);
+    ShapeTuple copy_shape{n_elem};
+    CopyVecDataToArray(view, src_data->data(), copy_shape);
+    CopyVecDataToArray(view, dst_data->data(), copy_shape,
+                       /*dst_elem_offset=*/n_elem);
+    return view;
+  }
+
+  // The commit of the plain auxiliary data manager is no-op.
+  void CommitCompactKVAuxDataCopy() final {}
 
  private:
   /*!
@@ -488,81 +551,136 @@ class PlainPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   NDArray k_ragged_rope_pos_offset_device_;
   NDArray q_rope_position_map_device_;
   NDArray append_position_map_device_;
+  NDArray tree_attn_mask_device_;
+  NDArray tree_attn_mn_indptr_device_;
+  NDArray commit_copy_length_indptr_device_;
+  NDArray commit_copy_src_dst_pos_in_page_table_device_;
 };
 
 /*!
  * \brief The cached auxiliary data manager class.
  * It allocates a large on-device array to store all the auxiliary data.
  * For each `CopyXXXAsync`, it copies the input data to a local cache on host.
- * In `CommitCopy`, it copies all the data in the local cache to the device
+ * In `CommitAttnAuxDataCopy`, it copies all the data in the local cache to the device
  * array for a single time, and thus reduce the number of host-to-device copies needed.
  */
 class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
  public:
   explicit CachedPagedKVCacheAuxDataManager(int64_t reserved_num_seqs, int64_t num_total_pages,
                                             int64_t prefill_chunk_size, DLDataType dtype_aux,
-                                            DLDevice device, Device preferred_host_device,
+                                            Device device, Device preferred_host_device,
                                             TVMStreamHandle copy_stream)
       : PagedKVCacheAuxDataManager(dtype_aux, device, preferred_host_device, copy_stream),
         elem_byte_size_((dtype_aux.bits * dtype_aux.lanes + 7) / 8),
         offset_alignment_(cuda_byte_alignment_ / elem_byte_size_) {
-    // - Calculate cache size of all the auxiliary arrays in
+    // - Calculate cache size of all the attention auxiliary arrays in
     // local cache and the large on-device array.
-    int64_t cache_size = CalculateCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size);
+    int64_t attn_aux_data_cache_size =
+        CalculateAttnAuxDataCacheSize(reserved_num_seqs, num_total_pages, prefill_chunk_size);
     // - Initialize the host auxiliary data buffer.
-    merged_aux_data_host_ = HostMemoryVector(cache_size, dtype_aux, preferred_host_device);
+    merged_attn_aux_data_host_ =
+        HostMemoryVector(attn_aux_data_cache_size, dtype_aux, preferred_host_device);
     // - Initialize the device auxiliary data buffer.
-    memory::Allocator* allocator =
-        memory::MemoryManager::GetOrCreateAllocator(device, memory::AllocatorType::kNaive);
-    ICHECK_NOTNULL(allocator);
-    merged_aux_data_device_ =
-        memory::Storage(allocator->Alloc(device, {cache_size}, dtype_aux), allocator);
+    merged_attn_aux_data_device_ = NDArray::Empty({attn_aux_data_cache_size}, dtype_aux, device);
+
+    // - Calculate cache size of all the compact KV auxiliary arrays in
+    // local cache and the large on-device array.
+    int64_t compact_kv_aux_data_cache_size =
+        CalculateCompactKVAuxDataCacheSize(reserved_num_seqs, prefill_chunk_size);
+    // - Initialize the host auxiliary data buffer.
+    merged_compact_kv_aux_data_host_ =
+        HostMemoryVector(compact_kv_aux_data_cache_size, dtype_aux, preferred_host_device);
+    merged_compact_kv_aux_data_device_ =
+        NDArray::Empty({compact_kv_aux_data_cache_size}, dtype_aux, device);
   }
 
-  void ResetCopy() final { copy_offset_ = 0; }
+  void ResetAttnAuxDataCopy() final { attn_aux_data_copy_offset_ = 0; }
   NDArray CopyQOIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
-    return CopyVecToCache(data);
+    return CopyAttnAuxVecToCache(data);
   }
   NDArray CopyPageIndptrOnDepthAsync(HostMemoryVector* data, int depth) final {
-    return CopyVecToCache(data);
+    return CopyAttnAuxVecToCache(data);
   }
   NDArray CopyPageIndicesOnDepthAsync(HostMemoryVector* data, int depth) final {
-    return CopyVecToCache(data);
+    return CopyAttnAuxVecToCache(data);
   }
   NDArray CopyLastPageLenOnDepthAsync(HostMemoryVector* data, int depth) final {
-    return CopyVecToCache(data);
+    return CopyAttnAuxVecToCache(data);
   }
   NDArray CopyKRoPEPosOffsetOnDepthAsync(HostMemoryVector* data, int depth) final {
-    return CopyVecToCache(data);
+    return CopyAttnAuxVecToCache(data);
   }
   NDArray CopyCurAppendLengthIndptrAsync(HostMemoryVector* data) final {
-    return CopyVecToCache(data);
+    return CopyAttnAuxVecToCache(data);
   }
   NDArray CopyKRaggedRoPEPosOffsetAsync(HostMemoryVector* data) final {
-    return CopyVecToCache(data);
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyAttnAuxVecToCache(data); }
+  NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyTreeAttnMaskAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
+  }
+  NDArray CopyTreeAttnMNIndptrAsync(HostMemoryVector* data) final {
+    return CopyAttnAuxVecToCache(data);
   }
-  NDArray CopyQRoPEPosMapAsync(HostMemoryVector* data) final { return CopyVecToCache(data); }
-  NDArray CopyAppendPositionMapAsync(HostMemoryVector* data) final { return CopyVecToCache(data); }
   NDArray CopyLengthInfoOnDepthAsync(HostMemoryVector* last_page_len,
                                      HostMemoryVector* sliding_window_offset,
                                      HostMemoryVector* sink_size, int depth) final {
     int64_t n_elem = last_page_len->size();
-    std::memcpy(merged_aux_data_host_.data() + copy_offset_, last_page_len->data(),
-                n_elem * elem_byte_size_);
-    std::memcpy(merged_aux_data_host_.data() + copy_offset_ + n_elem, sliding_window_offset->data(),
-                n_elem * elem_byte_size_);
-    std::memcpy(merged_aux_data_host_.data() + copy_offset_ + 2 * n_elem, sink_size->data(),
-                n_elem * elem_byte_size_);
-    NDArray view = merged_aux_data_device_->AllocNDArray(copy_offset_ * elem_byte_size_,
-                                                         {3, n_elem}, dtype_aux_);
-    copy_offset_ += CeilDivElemAlignment(3 * n_elem);
+    std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_,
+                last_page_len->data(), n_elem * elem_byte_size_);
+    std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + n_elem,
+                sliding_window_offset->data(), n_elem * elem_byte_size_);
+    std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_ + 2 * n_elem,
+                sink_size->data(), n_elem * elem_byte_size_);
+    NDArray view = merged_attn_aux_data_device_.CreateView(
+        {3, n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_);
+    attn_aux_data_copy_offset_ += CeilDivElemAlignment(3 * n_elem);
+    return view;
+  }
+
+  void CommitAttnAuxDataCopy() final {
+    std::vector<int64_t> copy_shape{attn_aux_data_copy_offset_};
+    DLTensor copy_dst;
+    copy_dst.data = merged_attn_aux_data_device_->data;
+    copy_dst.device = device_;
+    copy_dst.ndim = 1;
+    copy_dst.dtype = dtype_aux_;
+    copy_dst.shape = copy_shape.data();
+    copy_dst.strides = nullptr;
+    copy_dst.byte_offset = 0;
+
+    DLTensor copy_src = copy_dst;
+    copy_src.data = merged_attn_aux_data_host_.data();
+    copy_src.device = Device{kDLCPU, 0};
+    NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
+  }
+
+  void ResetCompactKVAuxDataCopy() final { compact_kv_aux_data_copy_offset_ = 0; }
+
+  NDArray CopyCommitLengthIndptrAsync(HostMemoryVector* data) final {
+    return CopyCompactKVAuxVecToCache(data);
+  }
+  NDArray CopyCommitSrcDstPosInPageTableAsync(HostMemoryVector* src_data,
+                                              HostMemoryVector* dst_data) final {
+    int64_t n_elem = src_data->size();
+    std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_,
+                src_data->data(), n_elem * elem_byte_size_);
+    std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_ + n_elem,
+                dst_data->data(), n_elem * elem_byte_size_);
+    NDArray view = merged_compact_kv_aux_data_device_.CreateView(
+        {2, n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_);
+    compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(2 * n_elem);
     return view;
   }
 
-  void CommitCopy() final {
-    std::vector<int64_t> copy_shape{copy_offset_};
+  void CommitCompactKVAuxDataCopy() final {
+    std::vector<int64_t> copy_shape{compact_kv_aux_data_copy_offset_};
     DLTensor copy_dst;
-    copy_dst.data = merged_aux_data_device_->buffer.data;
+    copy_dst.data = merged_compact_kv_aux_data_device_->data;
     copy_dst.device = device_;
     copy_dst.ndim = 1;
     copy_dst.dtype = dtype_aux_;
@@ -571,7 +689,7 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
     copy_dst.byte_offset = 0;
 
     DLTensor copy_src = copy_dst;
-    copy_src.data = merged_aux_data_host_.data();
+    copy_src.data = merged_compact_kv_aux_data_host_.data();
     copy_src.device = Device{kDLCPU, 0};
     NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream_);
   }
@@ -581,8 +699,8 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
    * \brief Calculate the start element offsets of the auxiliary arrays in the local cache.
    * \return Return the local cache size (total number of elements in the local cache).
    */
-  int64_t CalculateCacheSize(int64_t reserved_num_seqs, int64_t num_total_pages,
-                             int64_t prefill_chunk_size) {
+  int64_t CalculateAttnAuxDataCacheSize(int64_t reserved_num_seqs, int64_t num_total_pages,
+                                        int64_t prefill_chunk_size) {
     int64_t cache_size = 0;
     // - Array size of the arrays that every depth has.
     // Corresponding to the following arrays respectively
@@ -604,10 +722,28 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
     //  - k_ragged_rope_pos_offset
     //  - q_rope_position_map
     //  - append_position_map
+    //  - tree_attn_mask
+    //  - tree_attn_mn_indptr
     cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
     cache_size += CeilDivElemAlignment(reserved_num_seqs);
     cache_size += CeilDivElemAlignment(prefill_chunk_size);
     cache_size += CeilDivElemAlignment(prefill_chunk_size);
+    cache_size +=
+        CeilDivElemAlignment(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs);
+    cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
+
+    return cache_size;
+  }
+
+  int64_t CalculateCompactKVAuxDataCacheSize(int64_t reserved_num_seqs,
+                                             int64_t prefill_chunk_size) {
+    int64_t cache_size = 0;
+    // Corresponding to the following arrays respectively
+    //  - commit_copy_length_indptr
+    //  - commit_copy_src_dst_pos_in_page_table
+    cache_size += CeilDivElemAlignment(reserved_num_seqs + 1);
+    cache_size += CeilDivElemAlignment(
+        2 * std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size));
 
     return cache_size;
   }
@@ -616,13 +752,23 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
    * \brief Copy the input data to the cache at the given offset.
    * And return the NDArray view of the cache starting at the offset.
    */
-  NDArray CopyVecToCache(HostMemoryVector* data) {
+  NDArray CopyAttnAuxVecToCache(HostMemoryVector* data) {
     int64_t n_elem = data->size();
-    std::memcpy(merged_aux_data_host_.data() + copy_offset_, data->data(),
+    std::memcpy(merged_attn_aux_data_host_.data() + attn_aux_data_copy_offset_, data->data(),
                 n_elem * elem_byte_size_);
-    NDArray view =
-        merged_aux_data_device_->AllocNDArray(copy_offset_ * elem_byte_size_, {n_elem}, dtype_aux_);
-    copy_offset_ += CeilDivElemAlignment(n_elem);
+    NDArray view = merged_attn_aux_data_device_.CreateView(
+        {n_elem}, dtype_aux_, attn_aux_data_copy_offset_ * elem_byte_size_);
+    attn_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem);
+    return view;
+  }
+
+  NDArray CopyCompactKVAuxVecToCache(HostMemoryVector* data) {
+    int64_t n_elem = data->size();
+    std::memcpy(merged_compact_kv_aux_data_host_.data() + compact_kv_aux_data_copy_offset_,
+                data->data(), n_elem * elem_byte_size_);
+    NDArray view = merged_compact_kv_aux_data_device_.CreateView(
+        {n_elem}, dtype_aux_, compact_kv_aux_data_copy_offset_ * elem_byte_size_);
+    compact_kv_aux_data_copy_offset_ += CeilDivElemAlignment(n_elem);
     return view;
   }
 
@@ -635,9 +781,12 @@ class CachedPagedKVCacheAuxDataManager : public PagedKVCacheAuxDataManager {
   const int64_t elem_byte_size_;
   const int64_t offset_alignment_;
 
-  int64_t copy_offset_ = 0;
-  HostMemoryVector merged_aux_data_host_;
-  memory::Storage merged_aux_data_device_;
+  int64_t attn_aux_data_copy_offset_ = 0;
+  int64_t compact_kv_aux_data_copy_offset_ = 0;
+  HostMemoryVector merged_attn_aux_data_host_;
+  HostMemoryVector merged_compact_kv_aux_data_host_;
+  NDArray merged_attn_aux_data_device_;
+  NDArray merged_compact_kv_aux_data_device_;
 };
 
 /*!
@@ -726,8 +875,24 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   bool dirty_aux_data_device_ = false;
   /*! \brief The batch size of the current round of forwarding. */
   int64_t cur_batch_size_;
+  /*! \brief The ids of the sequences in the current round of forwarding. */
+  IntTuple cur_seq_ids_;
   /*! \brief The append lengths of the sequences in the current round of forwarding. */
   IntTuple cur_append_lengths_;
+  /*! \brief The token tree parent array of the sequences in the current round of forwarding. */
+  IntTuple cur_token_tree_parent_ptr_{nullptr};
+  /*! \brief The depth of each node in the token tree, for the sequences in the current batch. */
+  std::vector<std::vector<int32_t>> cur_token_tree_node_depths_;
+  /*! \brief Whether the current batch of sequences are token chains (not token trees). */
+  bool is_chain_;
+  /*! \brief Number of fork depth in the current round of forward. */
+  int num_depths_;
+  /*! \brief Whether to compute attention after appending KV into cache or not. */
+  bool append_before_attn_;
+  /*! \brief Whether to use decode kernel for each depth. (see GetChunkedBlockIds) */
+  std::vector<bool> use_decode_kernel_;
+  /*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */
+  bool is_decode_request_;
   /*! \brief The auxiliary data manager for attention. */
   std::unique_ptr<PagedKVCacheAuxDataManager> aux_data_manager_;
 
@@ -755,6 +920,11 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   HostMemoryVector q_rope_position_map_host_;
   HostMemoryVector append_position_map_host_;
   HostMemoryVector cur_append_lengths_indptr_host_;
+  HostMemoryVector tree_attn_mask_host_;
+  HostMemoryVector tree_attn_mn_indptr_host_;
+  HostMemoryVector commit_copy_length_indptr_host_;
+  HostMemoryVector commit_copy_src_pos_in_page_table_host_;
+  HostMemoryVector commit_copy_dst_pos_in_page_table_host_;
 
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
@@ -767,6 +937,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   NDArray k_ragged_rope_pos_offset_view_;
   NDArray q_rope_position_map_view_;
   NDArray append_position_map_view_;
+  NDArray tree_attn_mask_view_;
+  NDArray tree_attn_mn_indptr_view_;
   NDArray temp_attn_output_view_;
   NDArray temp_attn_scores_view_;
   NDArray merged_attn_scores_view_;
@@ -777,11 +949,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   std::vector<NDArray> k_rope_pos_offset_view_;
 
   PackedFunc f_transpose_append_;
+  PackedFunc f_compact_copy_;
   PackedFunc f_attention_prefill_;
   PackedFunc f_attention_decode_;
   PackedFunc f_attention_prefill_sliding_window_;
   PackedFunc f_attention_decode_sliding_window_;
   PackedFunc f_attention_prefill_ragged_;
+  PackedFunc f_attention_prefill_with_tree_mask_;
   Optional<PackedFunc> f_attention_prefill_ragged_begin_forward_;
   Optional<PackedFunc> f_attention_prefill_ragged_end_forward_;
   Optional<PackedFunc> f_attention_prefill_begin_forward_;
@@ -793,16 +967,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   PackedFunc f_copy_single_page_;
   Optional<PackedFunc> f_debug_get_kv_;
 
-  /*! \brief Number of fork depth in the current round of forward. */
-  int num_depths_;
-  /*! \brief Whether to compute attention after appending KV into cache or not. */
-  bool append_before_attn_;
-  /*! \brief Whether to use decode kernel for each depth. (see GetChunkedBlockIds) */
-  std::vector<bool> use_decode_kernel_;
-  /*! \brief Whether the attention request is a decode request, set in BeginForwardFunction. */
-  bool is_decode_request_;
   /*! \brief The device this PagedKVCache runs on. */
-  DLDevice device_;
+  Device device_;
   /*! \brief The device stream for the default computation operations. */
   TVMStreamHandle compute_stream_ = nullptr;
   /*! \brief The device stream for copying auxiliary data structure to GPU. */
@@ -815,10 +981,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       int64_t num_layers, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim,
       int64_t reserved_num_seqs, int64_t num_total_pages, int64_t prefill_chunk_size,
       bool support_sliding_window, RoPEMode rope_mode, double rotary_scale, double rotary_theta,
-      DLDataType dtype, DLDevice device, PackedFunc f_transpose_append,
+      DLDataType dtype, Device device, PackedFunc f_transpose_append, PackedFunc f_compact_copy,
       PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
       PackedFunc f_attention_prefill_sliding_window, PackedFunc f_attention_decode_sliding_window,
-      PackedFunc f_attention_prefill_ragged,
+      PackedFunc f_attention_prefill_ragged, PackedFunc f_attention_prefill_with_tree_mask,
       Optional<PackedFunc> f_attention_prefill_ragged_begin_forward,
       Optional<PackedFunc> f_attention_prefill_ragged_end_forward,
       Optional<PackedFunc> f_attention_prefill_begin_forward,
@@ -839,11 +1005,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         rotary_scale_(rotary_scale),
         rotary_theta_(rotary_theta),
         f_transpose_append_(std::move(f_transpose_append)),
+        f_compact_copy_(std::move(f_compact_copy)),
         f_attention_prefill_(std::move(f_attention_prefill)),
         f_attention_decode_(std::move(f_attention_decode)),
         f_attention_prefill_sliding_window_(std::move(f_attention_prefill_sliding_window)),
         f_attention_decode_sliding_window_(std::move(f_attention_decode_sliding_window)),
         f_attention_prefill_ragged_(std::move(f_attention_prefill_ragged)),
+        f_attention_prefill_with_tree_mask_(std::move(f_attention_prefill_with_tree_mask)),
         f_attention_prefill_ragged_begin_forward_(
             std::move(f_attention_prefill_ragged_begin_forward)),
         f_attention_prefill_ragged_end_forward_(std::move(f_attention_prefill_ragged_end_forward)),
@@ -887,6 +1055,19 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         HostMemoryVector(prefill_chunk_size, dtype_aux_, preferred_host_device);
     cur_append_lengths_indptr_host_ =
         HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
+    tree_attn_mask_host_ =
+        HostMemoryVector(kTreeAttnMaxTreeSize * kTreeAttnMaxTreeSize * reserved_num_seqs,
+                         dtype_aux_, preferred_host_device);
+    tree_attn_mn_indptr_host_ =
+        HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
+    commit_copy_length_indptr_host_ =
+        HostMemoryVector(reserved_num_seqs + 1, dtype_aux_, preferred_host_device);
+    commit_copy_src_pos_in_page_table_host_ =
+        HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size),
+                         dtype_aux_, preferred_host_device);
+    commit_copy_dst_pos_in_page_table_host_ =
+        HostMemoryVector(std::min(kTreeAttnMaxTreeSize * reserved_num_seqs, prefill_chunk_size),
+                         dtype_aux_, preferred_host_device);
 
     for (int d = 0; d < kPagedKVCacheMaxBlockDepth; ++d) {
       temp_attn_workspace_.push_back(
@@ -1108,6 +1289,42 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
   }
 
+  void CompactKVCopy() {
+    int total_copy_length = commit_copy_length_indptr_host_.back();
+    ICHECK_GE(total_copy_length, 0);
+    if (total_copy_length == 0) {
+      return;
+    }
+
+    // Copy indptr/src/dst arrays to GPU.
+    aux_data_manager_->ResetCompactKVAuxDataCopy();
+    NDArray commit_copy_length_indptr_view =
+        aux_data_manager_->CopyCommitLengthIndptrAsync(&commit_copy_length_indptr_host_);
+    NDArray commit_copy_src_dst_pos_in_page_table_view =
+        aux_data_manager_->CopyCommitSrcDstPosInPageTableAsync(
+            &commit_copy_src_pos_in_page_table_host_, &commit_copy_dst_pos_in_page_table_host_);
+    aux_data_manager_->CommitCompactKVAuxDataCopy();
+
+    // Invoke the copy kernel on copy stream.
+    if (copy_stream_ != compute_stream_) {
+      // Set the copy stream for copy.
+      DeviceAPI::Get(device_)->SetStream(device_, copy_stream_);
+    }
+    ICHECK(f_compact_copy_.defined()) << "Function \"f_compact_copy\" is not defined.";
+    for (int layer = 0; layer < num_layers_; ++layer) {
+      f_compact_copy_(pages_[layer], commit_copy_length_indptr_view,
+                      commit_copy_src_dst_pos_in_page_table_view, cur_batch_size_);
+    }
+    if (copy_stream_ != compute_stream_) {
+      // Set the compute stream back.
+      DeviceAPI::Get(device_)->SetStream(device_, compute_stream_);
+    }
+
+    // Note: We do not explicitly synchronize the copy stream here.
+    // The safety is guaranteed by the synchronization pushed by the next round
+    // of BeginForward, which also copies auxiliary data structure to GPU.
+  }
+
   void EnableSlidingWindowForSeq(int64_t seq_id, int32_t sliding_window_size,
                                  int32_t attn_sink_size) final {
     CHECK(support_sliding_window_) << "The KV cache does not support sliding window.";
@@ -1143,6 +1360,10 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     CHECK_LE(n, it->second.seq_length)
         << "The sequence only has length " << it->second.seq_length
         << ", while the length of pop is " << n << " which exceeds the whole sequence length.";
+    if (n == 0) {
+      return;
+    }
+
     int32_t block_idx = it->second.last_block_idx;
     // The block should have at least one reference, which comes from the sequence.
     ICHECK_GE(global_block_pool_[block_idx].external_ref_cnt, 1);
@@ -1211,13 +1432,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
 
   /************** Attention **************/
 
-  void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) final {
+  void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
+                    const Optional<IntTuple>& opt_token_tree_parent_ptr) final {
+    CHECK(!cur_token_tree_parent_ptr_.defined())
+        << "The last round of forward which involves token tree has not been committed. Please "
+           "call \"CommitAcceptedTokenTreeNodes\" to commit the accepted tokens.";
+
     CHECK_EQ(seq_ids.size(), append_lengths.size())
         << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size ("
         << append_lengths.size() << ") mismatch.";
     cur_batch_size_ = seq_ids.size();
+    cur_seq_ids_ = seq_ids;
     cur_append_lengths_ = append_lengths;
 
+    // - Check token tree validity and process the token tree.
+    is_chain_ = true;
+    tree_attn_mask_host_.clear();
+    tree_attn_mn_indptr_host_.clear();
+    if (opt_token_tree_parent_ptr.defined()) {
+      is_chain_ = ConstructTokenTreeMask(opt_token_tree_parent_ptr.value());
+    }
+
     // - Collect sequence/block/page information for attention.
     std::vector<Sequence*> sequences;
     std::vector<int32_t> last_block_length_before_append;
@@ -1322,7 +1557,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       int64_t append_length = append_lengths[i];
       const Block& block = global_block_pool_[sequences[i]->last_block_idx];
       for (int64_t pos = 0; pos < append_length; ++pos) {
-        q_rope_position_map_host_.push_back(k_ragged_rope_pos_offset_host_[i] + pos);
+        q_rope_position_map_host_.push_back(
+            k_ragged_rope_pos_offset_host_[i] +
+            (is_chain_ ? pos : cur_token_tree_node_depths_[i][pos]));
 
         int32_t pos_in_block = block.seq_length - append_length + pos;
         if (last_block_length_before_append[i] + pos < block.sink_length) {
@@ -1412,6 +1649,81 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
   }
 
+  void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) final {
+    CHECK_NE(cur_batch_size_, -1)
+        << "Cannot commit accepted token tree nodes since BeginForward is not invoked.";
+    CHECK_EQ(leaf_indices.size(), cur_batch_size_)
+        << "The number of input leaf indices does not equal to the current batch size.";
+
+    for (int i = 0; i < cur_batch_size_; ++i) {
+      CHECK_GE(leaf_indices[i], 0)
+          << "Invalid tree index " << leaf_indices[i] << " which is negative";
+      CHECK_LT(leaf_indices[i], cur_append_lengths_[i])
+          << "Invalid tree index " << leaf_indices[i]
+          << " which is larger than or equals to the append length " << cur_append_lengths_[i]
+          << " of the sequence";
+    }
+
+    if (!is_chain_) {
+      commit_copy_length_indptr_host_.clear();
+      commit_copy_src_pos_in_page_table_host_.clear();
+      commit_copy_dst_pos_in_page_table_host_.clear();
+      commit_copy_length_indptr_host_.push_back(0);
+
+      for (int i = 0; i < cur_batch_size_; ++i) {
+        // Get the accepted node path on the token tree.
+        std::vector<int32_t> path_on_tree;
+        path_on_tree.reserve(cur_token_tree_node_depths_[i][leaf_indices[i]] + 1);
+        int node = leaf_indices[i];
+        while (node != -1) {
+          path_on_tree.push_back(node);
+          node = cur_token_tree_parent_ptr_[cur_append_lengths_indptr_host_[i] + node];
+        }
+        ICHECK_EQ(path_on_tree.size(), cur_token_tree_node_depths_[i][leaf_indices[i]] + 1);
+        // Get the destination array (range [0, path_length - 1)) of KV cache copy.
+        std::vector<int32_t> copy_dst_pos_in_seq;
+        copy_dst_pos_in_seq.resize(path_on_tree.size());
+        std::iota(copy_dst_pos_in_seq.rbegin(), copy_dst_pos_in_seq.rend(), /*value=*/0);
+        // Remove the positions whose KV data do not need copy.
+        while (!path_on_tree.empty() && path_on_tree.back() == copy_dst_pos_in_seq.back()) {
+          path_on_tree.pop_back();
+          copy_dst_pos_in_seq.pop_back();
+        }
+        // Reverse the position arrays so that they are in ascending order.
+        std::reverse(path_on_tree.begin(), path_on_tree.end());
+        std::reverse(copy_dst_pos_in_seq.begin(), copy_dst_pos_in_seq.end());
+
+        // Convert the in-sequence src/dst positions to src/dst positions in page table
+        // by looking up "append_position_map".
+        for (int p = 0; p < static_cast<int>(path_on_tree.size()); ++p) {
+          commit_copy_src_pos_in_page_table_host_.push_back(
+              append_position_map_host_[cur_append_lengths_indptr_host_[i] + path_on_tree[p]]);
+          commit_copy_dst_pos_in_page_table_host_.push_back(
+              append_position_map_host_[cur_append_lengths_indptr_host_[i] +
+                                        copy_dst_pos_in_seq[p]]);
+        }
+        commit_copy_length_indptr_host_.push_back(commit_copy_length_indptr_host_.back() +
+                                                  path_on_tree.size());
+      }
+
+      // Compact the KV data for each sequence by copying KV data.
+      CompactKVCopy();
+    }
+
+    // - Update the KV cache page data structure.
+    //   Note: Function "PopN" only changes the page table structure and does not
+    //         change the KV cache data. Therefore, we can directly use it, since
+    //         we have already launched all copies.
+    for (int i = 0; i < cur_batch_size_; ++i) {
+      int64_t length_to_pop =
+          cur_append_lengths_[i] - cur_token_tree_node_depths_[i][leaf_indices[i]] - 1;
+      PopN(cur_seq_ids_[i], length_to_pop);
+    }
+
+    // Reset the token tree.
+    cur_token_tree_parent_ptr_ = IntTuple{nullptr};
+  }
+
   NDArray GetQueryPositions() final {
     // Sync the copy stream and the compute stream.
     ComputeStreamWaitForCopyStream();
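
The accepted-path bookkeeping in CommitAcceptedTokenTreeNodes above can be
summarized with a short host-side sketch. The values are hypothetical (one
sequence with parent array [-1, 0, 0, 1] and node depths [0, 1, 1, 2]); the
sketch mirrors the leaf-to-root walk, the trimming of entries that are already
in place, and the final PopN of the rejected tail.

    parents = [-1, 0, 0, 1]   # token tree of one sequence (hypothetical)
    depths = [0, 1, 1, 2]     # depth of each tree node
    leaf = 3                  # accepted leaf index for this sequence

    # Walk leaf -> root to collect the accepted path (in tree-node indices).
    path, node = [], leaf
    while node != -1:
        path.append(node)
        node = parents[node]
    assert len(path) == depths[leaf] + 1          # path == [3, 1, 0]

    # The accepted tokens must end up at in-sequence positions 0..len(path)-1.
    dst = list(range(len(path) - 1, -1, -1))      # [2, 1, 0], aligned with path

    # Skip entries whose KV data is already in place and keep the rest in
    # ascending order; these in-sequence positions are then translated into
    # page-table positions via append_position_map before the copy kernel runs.
    copies = sorted((s, d) for s, d in zip(path, dst) if s != d)
    assert copies == [(3, 2)]                     # only node 3 moves, to position 2

    # Rejected tail length removed afterwards with PopN:
    append_length = 4
    assert append_length - depths[leaf] - 1 == 1  # node 2 is dropped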
@@ -1502,6 +1814,73 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     return block_idx;
   }
 
+  bool ConstructTokenTreeMask(const IntTuple& token_tree_parent_ptr) {
+    // We check if the token tree deteriorates to a chain,
+    // because chain cases can have simplified attention work flow.
+    bool is_chain = true;
+    cur_token_tree_parent_ptr_ = token_tree_parent_ptr;
+    cur_token_tree_node_depths_.clear();
+    cur_token_tree_node_depths_.reserve(cur_batch_size_);
+
+    int64_t sum_append_length = 0;
+    // - Construct the mn indptr array, which is the indptr of the mask size of each sequence.
+    tree_attn_mn_indptr_host_.push_back(0);
+    for (int64_t append_length : cur_append_lengths_) {
+      sum_append_length += append_length;
+      tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() +
+                                          static_cast<int32_t>(append_length * append_length));
+    }
+    CHECK_EQ(token_tree_parent_ptr.size(), sum_append_length)
+        << "Invalid token tree size. The sum of \"append_lengths\" is " << sum_append_length
+        << " while there are " << token_tree_parent_ptr.size()
+        << " elements in \"token_tree_parent_ptr\".";
+
+    // - Construct the mask of each sequence.
+    int processed_pos = 0;
+    for (int i = 0; i < cur_batch_size_; ++i) {
+      int64_t append_length = cur_append_lengths_[i];
+      std::vector<std::vector<int32_t>> mask;
+      std::vector<int32_t> depth;
+      mask.reserve(append_length);
+      depth.reserve(append_length);
+      for (int64_t n = 0; n < append_length; ++n) {
+        CHECK_LT(token_tree_parent_ptr[processed_pos], n)
+            << "Invalid token tree. The parent of node " << n << " in tree " << i << " is "
+            << token_tree_parent_ptr[processed_pos] << ", which is not smaller than " << n;
+        CHECK_GE(token_tree_parent_ptr[processed_pos], -1)
+            << "Invalid token tree. The parent of node " << n << " in tree " << i << " is "
+            << token_tree_parent_ptr[processed_pos];
+        if (token_tree_parent_ptr[processed_pos] != n - 1) {
+          // The parent of the current node is not the last node.
+          // Therefore the tree is not a chain.
+          is_chain = false;
+        }
+
+        std::vector<int32_t> single_pos_mask;
+        if (token_tree_parent_ptr[processed_pos] != -1) {
+          // The current node has a parent in the token tree.
+          single_pos_mask = {mask[token_tree_parent_ptr[processed_pos]].begin(),
+                             mask[token_tree_parent_ptr[processed_pos]].end()};
+          depth.push_back(depth[token_tree_parent_ptr[processed_pos]] + 1);
+        } else {
+          // The current node is root in the token tree.
+          single_pos_mask.resize(append_length, /*value=*/0);
+          depth.push_back(0);
+        }
+        single_pos_mask[n] = 1;
+        mask.push_back(single_pos_mask);
+        for (int32_t mask_val : single_pos_mask) {
+          tree_attn_mask_host_.push_back(mask_val);
+        }
+
+        ++processed_pos;
+      }
+      cur_token_tree_node_depths_.push_back(std::move(depth));
+    }
+
+    return is_chain;
+  }
+
   /*!
    * \brief Slide the KV cache window of the given sequence when
    * it has sliding window enabled.
@@ -1766,12 +2145,27 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
           attn_score_scaling_factor);
     } else {
       // Compute appended text self-attention
-      f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data,
-                                  cur_append_length_indptr_view_, q_rope_position_map_view_,
-                                  k_ragged_rope_pos_offset_view_, output, merged_attn_scores_view_,
-                                  /*causal=*/1,
-                                  /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_,
-                                  rotary_theta_, attn_score_scaling_factor);
+      if (is_chain_) {
+        // If the batch does not form a tree, use raggedness prefill kernel.
+        f_attention_prefill_ragged_(q_data, cur_append_length_indptr_view_, k_data, v_data,
+                                    cur_append_length_indptr_view_, q_rope_position_map_view_,
+                                    k_ragged_rope_pos_offset_view_, output,
+                                    merged_attn_scores_view_,
+                                    /*causal=*/1,
+                                    /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, rotary_scale_,
+                                    rotary_theta_, attn_score_scaling_factor);
+      } else {
+        // The batch requires tree attention.
+        ICHECK(tree_attn_mask_view_.defined());
+        ICHECK(tree_attn_mn_indptr_view_.defined());
+        ICHECK(f_attention_prefill_with_tree_mask_.defined())
+            << "Function \"f_attention_prefill_with_tree_mask_\" is not defined.";
+        f_attention_prefill_with_tree_mask_(
+            q_data, cur_append_length_indptr_view_, k_data, v_data, cur_append_length_indptr_view_,
+            q_rope_position_map_view_, tree_attn_mn_indptr_view_, tree_attn_mask_view_, output,
+            merged_attn_scores_view_, /*rotary_mode=*/rope_mode_ == RoPEMode::kInline,
+            rotary_scale_, rotary_theta_, attn_score_scaling_factor, cur_batch_size_);
+      }
 
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
@@ -1840,7 +2234,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     ICHECK_EQ(total_append_length, append_position_map_host_.size());
 
     // - Reset the copy.
-    aux_data_manager_->ResetCopy();
+    aux_data_manager_->ResetAttnAuxDataCopy();
 
     // 1. q_rope_position_map
     // q_rope_position_map has to be synced first so that it has a 0 byte offset
@@ -1900,7 +2294,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // 9. append_position_map
     append_position_map_view_ =
         aux_data_manager_->CopyAppendPositionMapAsync(&append_position_map_host_);
-    // 10. Create view for temporary arrays for attention computation.
+    // 10. tree_attn_mask and tree_attn_mn_indptr
+    if (!is_chain_) {
+      tree_attn_mask_view_ = aux_data_manager_->CopyTreeAttnMaskAsync(&tree_attn_mask_host_);
+      tree_attn_mn_indptr_view_ =
+          aux_data_manager_->CopyTreeAttnMNIndptrAsync(&tree_attn_mn_indptr_host_);
+    } else {
+      tree_attn_mask_view_ = NDArray{nullptr};
+      tree_attn_mn_indptr_view_ = NDArray{nullptr};
+    }
+    // 11. Create view for temporary arrays for attention computation.
     temp_attn_output_view_ = temp_attn_output_device_.CreateView(
         {total_append_length, num_qo_heads_, head_dim_}, temp_attn_output_device_->dtype);
     temp_attn_scores_view_ = temp_attn_scores_device_.CreateView(
@@ -1909,7 +2312,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         {total_append_length, num_qo_heads_}, merged_attn_scores_device_->dtype);
 
     // - Commit the copy.
-    aux_data_manager_->CommitCopy();
+    aux_data_manager_->CommitAttnAuxDataCopy();
     // - Reset the dirty flag to false.
     dirty_aux_data_device_ = false;
   }
@@ -1922,21 +2325,44 @@ TVM_REGISTER_OBJECT_TYPE(PagedAttentionKVCacheObj);
 //-------------------------------------------------
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
-    .set_body_typed([](ShapeTuple cache_config, int64_t num_layers, int64_t num_qo_heads,
-                       int64_t num_kv_heads, int64_t head_dim, int rope_mode, double rotary_scale,
-                       double rotary_theta, NDArray init, PackedFunc f_transpose_append,
-                       PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-                       PackedFunc f_attention_prefill_sliding_window,  //
-                       PackedFunc f_attention_decode_sliding_window,
-                       PackedFunc f_attention_prefill_ragged,
-                       PackedFunc f_attention_prefill_ragged_begin_forward,
-                       PackedFunc f_attention_prefill_ragged_end_forward,
-                       PackedFunc f_attention_prefill_begin_forward,
-                       PackedFunc f_attention_prefill_end_forward,
-                       PackedFunc f_attention_decode_begin_forward,
-                       PackedFunc f_attention_decode_end_forward, PackedFunc f_merge_inplace,
-                       PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
-                       Optional<PackedFunc> f_debug_get_kv) {
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.size() == 25 || args.size() == 26 || args.size() == 27)
+          << "Invalid number of KV cache constructor args.";
+      ShapeTuple cache_config = args[0];
+      int64_t num_layers = args[1];
+      int64_t num_qo_heads = args[2];
+      int64_t num_kv_heads = args[3];
+      int64_t head_dim = args[4];
+      int rope_mode = args[5];
+      double rotary_scale = args[6];
+      double rotary_theta = args[7];
+      NDArray init = args[8];
+      PackedFunc f_transpose_append = args[9];
+      PackedFunc f_attention_prefill = args[10];
+      PackedFunc f_attention_decode = args[11];
+      PackedFunc f_attention_prefill_sliding_window = args[12];
+      PackedFunc f_attention_decode_sliding_window = args[13];
+      PackedFunc f_attention_prefill_ragged = args[14];
+      PackedFunc f_attention_prefill_ragged_begin_forward = args[15];
+      PackedFunc f_attention_prefill_ragged_end_forward = args[16];
+      PackedFunc f_attention_prefill_begin_forward = args[17];
+      PackedFunc f_attention_prefill_end_forward = args[18];
+      PackedFunc f_attention_decode_begin_forward = args[19];
+      PackedFunc f_attention_decode_end_forward = args[20];
+      PackedFunc f_merge_inplace = args[21];
+      PackedFunc f_split_rotary = args[22];
+      PackedFunc f_copy_single_page = args[23];
+      Optional<PackedFunc> f_debug_get_kv = args[24];
+      PackedFunc f_compact_copy{nullptr};
+      PackedFunc f_attention_prefill_with_tree_mask{nullptr};
+
+      if (args.size() >= 26) {
+        f_compact_copy = args[25].AsObjectRef<PackedFunc>();
+      }
+      if (args.size() >= 27) {
+        f_attention_prefill_with_tree_mask = args[26].AsObjectRef<PackedFunc>();
+      }
+
       CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
@@ -1952,28 +2378,52 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create")
           page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
           num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
           rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
-          std::move(f_attention_prefill), std::move(f_attention_decode),
+          std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
           std::move(f_attention_prefill_sliding_window),
           std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
+          std::move(f_attention_prefill_with_tree_mask),
           std::move(f_attention_prefill_ragged_begin_forward),
           std::move(f_attention_prefill_ragged_end_forward),
           std::move(f_attention_prefill_begin_forward), std::move(f_attention_prefill_end_forward),
           std::move(f_attention_decode_begin_forward), std::move(f_attention_decode_end_forward),
           std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
           std::move(f_debug_get_kv));
-      return AttentionKVCache(std::move(n));
+      *rv = AttentionKVCache(std::move(n));
     });
 
 TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
-    .set_body_typed([](ShapeTuple cache_config, int64_t num_layers, int64_t num_qo_heads,
-                       int64_t num_kv_heads, int64_t head_dim, int rope_mode, double rotary_scale,
-                       double rotary_theta, NDArray init, PackedFunc f_transpose_append,
-                       PackedFunc f_attention_prefill, PackedFunc f_attention_decode,
-                       PackedFunc f_attention_prefill_sliding_window,
-                       PackedFunc f_attention_decode_sliding_window,
-                       PackedFunc f_attention_prefill_ragged, PackedFunc f_merge_inplace,
-                       PackedFunc f_split_rotary, PackedFunc f_copy_single_page,
-                       Optional<PackedFunc> f_debug_get_kv) {
+    .set_body([](TVMArgs args, TVMRetValue* rv) {
+      CHECK(args.size() == 19 || args.size() == 20 || args.size() == 21)
+          << "Invalid number of KV cache constructor args.";
+      ShapeTuple cache_config = args[0];
+      int64_t num_layers = args[1];
+      int64_t num_qo_heads = args[2];
+      int64_t num_kv_heads = args[3];
+      int64_t head_dim = args[4];
+      int rope_mode = args[5];
+      double rotary_scale = args[6];
+      double rotary_theta = args[7];
+      NDArray init = args[8];
+      PackedFunc f_transpose_append = args[9];
+      PackedFunc f_attention_prefill = args[10];
+      PackedFunc f_attention_decode = args[11];
+      PackedFunc f_attention_prefill_sliding_window = args[12];
+      PackedFunc f_attention_decode_sliding_window = args[13];
+      PackedFunc f_attention_prefill_ragged = args[14];
+      PackedFunc f_merge_inplace = args[15];
+      PackedFunc f_split_rotary = args[16];
+      PackedFunc f_copy_single_page = args[17];
+      Optional<PackedFunc> f_debug_get_kv = args[18];
+      PackedFunc f_compact_copy{nullptr};
+      PackedFunc f_attention_prefill_with_tree_mask{nullptr};
+
+      if (args.size() >= 20) {
+        f_compact_copy = args[19].AsObjectRef<PackedFunc>();
+      }
+      if (args.size() >= 21) {
+        f_attention_prefill_with_tree_mask = args[20].AsObjectRef<PackedFunc>();
+      }
+
       CHECK_EQ(cache_config.size(), 5);
       int64_t reserved_num_seqs = cache_config[0];
       int64_t total_token_capacity = cache_config[1];
@@ -1989,13 +2439,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced")
           page_size, num_layers, num_qo_heads, num_kv_heads, head_dim, reserved_num_seqs,
           num_total_pages, prefill_chunk_size, support_sliding_window, RoPEMode(rope_mode),
           rotary_scale, rotary_theta, init->dtype, init->device, std::move(f_transpose_append),
-          std::move(f_attention_prefill), std::move(f_attention_decode),
+          std::move(f_compact_copy), std::move(f_attention_prefill), std::move(f_attention_decode),
           std::move(f_attention_prefill_sliding_window),
-          std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),  //
-          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,                                  //
+          std::move(f_attention_decode_sliding_window), std::move(f_attention_prefill_ragged),
+          std::move(f_attention_prefill_with_tree_mask),         //
+          NullOpt, NullOpt, NullOpt, NullOpt, NullOpt, NullOpt,  //
           std::move(f_merge_inplace), std::move(f_split_rotary), std::move(f_copy_single_page),
           std::move(f_debug_get_kv));
-      return AttentionKVCache(std::move(n));
+      *rv = AttentionKVCache(std::move(n));
     });
 
 }  // namespace relax_vm
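
For reference, the updated vm.builtin.paged_attention_kv_cache_create_reduced registration above accepts 19, 20, or 21 positional arguments; the optional trailing f_compact_copy and f_attention_prefill_with_tree_mask PackedFuncs are what enable tree attention. A minimal Python-side sketch of the argument order (only the order comes from the parsing above; the variable names are illustrative placeholders for compiled TIR PackedFuncs):

    # Sketch only: arguments 1-19 are required, 20 and 21 are optional.
    fcreate = tvm.get_global_func("vm.builtin.paged_attention_kv_cache_create_reduced")
    cache = fcreate(
        tvm.runtime.ShapeTuple(cache_config),  # [reserved_num_seqs, total_token_capacity, ...]
        num_layers, num_qo_heads, num_kv_heads, head_dim,
        rope_mode, rope_scale, rope_theta,
        init_array,                        # NDArray supplying the KV dtype and device
        ftranspose_append,
        fattn_prefill, fattn_decode,
        fattn_prefill_sliding_window, fattn_decode_sliding_window,
        fattn_prefill_ragged,
        fmerge_state, fsplit_rotary, fcopy_single_page,
        fcopy_cache,                       # f_debug_get_kv
        fcompact_copy,                     # optional 20th argument
        fattn_prefill_with_tree_mask,      # optional 21st argument
    )
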
diff --git a/src/runtime/relax_vm/rnn_state.cc b/src/runtime/relax_vm/rnn_state.cc
index 69225d6b2c..16fe6791b8 100644
--- a/src/runtime/relax_vm/rnn_state.cc
+++ b/src/runtime/relax_vm/rnn_state.cc
@@ -205,10 +205,24 @@ class RNNStateImpObj : public RNNStateObj {
 
   /************** Interaction **************/
 
-  void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths) {
+  void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
+                    const Optional<IntTuple>& opt_token_tree_parent_ptr) final {
     CHECK_EQ(seq_ids.size(), append_lengths.size())
         << "The seq_ids size (" << seq_ids.size() << ") and append_lengths 
size ("
         << append_lengths.size() << ") mismatch.";
+
+    if (opt_token_tree_parent_ptr.defined()) {
+      IntTuple token_tree_parent_ptr = opt_token_tree_parent_ptr.value();
+      int matched_pos = 0;
+      for (int64_t append_length : append_lengths) {
+        for (int64_t i = 0; i < append_length; ++i) {
+          CHECK_EQ(token_tree_parent_ptr[matched_pos], i - 1)
+              << "Unexpected token tree for RNN state. RNN state only supports chains as token "
+                 "trees.";
+          ++matched_pos;
+        }
+      }
+    }
     cur_batch_size_ = seq_ids.size();
     cur_append_lengths_ = append_lengths;
     cur_seq_ids_ = seq_ids;
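
The check above restricts RNN states to degenerate token trees: within each appended sequence, the parent of position i must be i - 1 (with -1 for the root), i.e. the tree must be a chain. A small Python sketch of the same validation (the function name is illustrative):

    def check_chain_token_trees(append_lengths, token_tree_parent_ptr):
        # Mirrors the loop above: walk the flattened parent pointers and require
        # that every per-sequence tree is a chain (parent of position i is i - 1).
        pos = 0
        for append_length in append_lengths:
            for i in range(append_length):
                assert token_tree_parent_ptr[pos] == i - 1, (
                    "RNN state only supports chains as token trees."
                )
                pos += 1
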
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 6504175b56..0a69d184e5 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -53,6 +53,7 @@ fenable_sliding_window_for_seq = None
 fpopn = None
 fbegin_forward = None
 fend_forward = None
+fcommit_accepted_token_tree_nodes = None
 fattention_with_fuse_qkv = None
 fis_empty = None
 fdebug_get_kv = None
@@ -64,18 +65,22 @@ fattn_decode = None
 fattn_prefill_sliding_window = None
 fattn_decode_sliding_window = None
 fattn_prefill_ragged = None
+fattn_prefill_with_tree_mask = None
 fmerge_state = None
 fsplit_rotary = None
 fattention_rotary = None
 fcopy_single_page = None
+fcompact_copy = None
 
 
 def set_global_func(head_dim, dtype):
     global fclear, fadd_sequence, fremove_sequence, ffork_sequence, fenable_sliding_window_for_seq
-    global fpopn, fbegin_forward, fend_forward, fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
-    global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode, fattn_prefill_ragged
+    global fpopn, fbegin_forward, fend_forward, fcommit_accepted_token_tree_nodes
+    global fattention_with_fuse_qkv, fis_empty, fdebug_get_kv
+    global ftranspose_append, fcopy_cache, fattn_prefill, fattn_decode
+    global fattn_prefill_ragged, fattn_prefill_with_tree_mask
     global fattn_prefill_sliding_window, fattn_decode_sliding_window
-    global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page
+    global fmerge_state, fsplit_rotary, fattention_rotary, fcopy_single_page, fcompact_copy
 
     fclear = tvm.get_global_func("vm.builtin.kv_state_clear")
     fadd_sequence = tvm.get_global_func("vm.builtin.kv_state_add_sequence")
@@ -87,6 +92,9 @@ def set_global_func(head_dim, dtype):
     fpopn = tvm.get_global_func("vm.builtin.kv_state_popn")
     fbegin_forward = tvm.get_global_func("vm.builtin.kv_state_begin_forward")
     fend_forward = tvm.get_global_func("vm.builtin.kv_state_end_forward")
+    fcommit_accepted_token_tree_nodes = tvm.get_global_func(
+        "vm.builtin.attention_kv_cache_commit_accepted_token_tree_nodes"
+    )
     fattention_with_fuse_qkv = tvm.get_global_func(
         "vm.builtin.attention_kv_cache_attention_with_fused_qkv"
     )
@@ -103,11 +111,13 @@ def set_global_func(head_dim, dtype):
         _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, True, target),
         _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, True, target),
         _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype, target),
+        _attention_prefill_with_tree_mask(num_kv_heads, num_qo_heads, head_dim, dtype, target),
         _merge_state_inplace(num_qo_heads, head_dim, dtype, target),
         llama_rope_with_position_map(
             rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype
         ),
         _copy_single_page(num_kv_heads, page_size, head_dim, dtype, target),
+        _compact_kv_copy(num_kv_heads, head_dim, dtype, target),
     ]:
         mod = tvm.IRModule({"main": tir_func})
         with target:
@@ -123,9 +133,11 @@ def set_global_func(head_dim, dtype):
         fattn_prefill_sliding_window,
         fattn_decode_sliding_window,
         fattn_prefill_ragged,
+        fattn_prefill_with_tree_mask,
         fmerge_state,
         fsplit_rotary,
         fcopy_single_page,
+        fcompact_copy,
     ) = builts
 
 
@@ -159,6 +171,8 @@ def create_kv_cache(head_dim, dtype, rope_mode, support_sliding_window):
         fsplit_rotary,
         fcopy_single_page,
         fcopy_cache,
+        fcompact_copy,
+        fattn_prefill_with_tree_mask,
     )
     return cache
 
@@ -211,7 +225,7 @@ def verify_cached_kv(kv_cache, seq_ids, expected_k, expected_v):
         tvm.testing.assert_allclose(values.numpy(), values_expected, rtol=1e-3, atol=1e-3)
 
 
-def f_apply_rotary(x, offset, scale, theta):
+def f_apply_rotary(x, offset, scale, theta, offset_list: Optional[List[int]] = None):
     # x: (N, H, D)
     assert len(x.shape) == 3
     nfeat = x.shape[-1]
@@ -220,7 +234,11 @@ def f_apply_rotary(x, offset, scale, theta):
     y = np.concatenate([-x[:, :, nfeat_half:], x[:, :, :nfeat_half]], axis=-1)
 
     inv_freq = scale / (theta ** (np.arange(0, nfeat, 2).astype("float32") / nfeat))
-    t = np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype)
+    t = (
+        np.arange(offset, offset + x.shape[0], dtype=inv_freq.dtype)
+        if offset_list is None
+        else (np.array(offset_list, dtype=inv_freq.dtype) + offset)
+    )
     freqs = np.einsum("i,j->ij", t, inv_freq)
     emb = np.concatenate((freqs, freqs), axis=-1)
     cos_values = np.cos(emb)
@@ -237,6 +255,8 @@ def apply_attention(
     cached_v: Dict[int, np.ndarray],
     sliding_window_sizes: Optional[List[int]] = None,
     attn_sink_sizes: Optional[List[int]] = None,
+    token_tree_parent_ptr_list: Optional[List[List[int]]] = None,
+    accepted_leaf_indices: Optional[List[int]] = None,
 ) -> None:
     seq_ids = []
     append_lengths = []
@@ -263,14 +283,42 @@ def apply_attention(
             cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
             cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
 
-    fbegin_forward(kv_cache, ShapeTuple(seq_ids), ShapeTuple(append_lengths))
+    assert (token_tree_parent_ptr_list is None) == (accepted_leaf_indices is None)
+    flattened_token_tree_parent_ptr = None
+    token_tree_node_depths_list: List[Optional[List[int]]] = [None for _ in batch]
+    if token_tree_parent_ptr_list:
+        assert len(token_tree_node_depths_list) == len(seq_ids)
+        assert len(accepted_leaf_indices) == len(seq_ids)
+        flattened_token_tree_parent_ptr = []
+        for i, (token_tree_parent_ptr, append_length) in enumerate(
+            zip(token_tree_parent_ptr_list, append_lengths)
+        ):
+            assert len(token_tree_parent_ptr) == append_length
+            flattened_token_tree_parent_ptr += token_tree_parent_ptr
+            token_tree_node_depths = []
+            for parent in token_tree_parent_ptr:
+                token_tree_node_depths.append(
+                    0 if parent == -1 else token_tree_node_depths[parent] + 1
+                )
+            token_tree_node_depths_list[i] = token_tree_node_depths
+
+    fbegin_forward(
+        kv_cache,
+        ShapeTuple(seq_ids),
+        ShapeTuple(append_lengths),
+        (
+            ShapeTuple(flattened_token_tree_parent_ptr)
+            if flattened_token_tree_parent_ptr is not None
+            else None
+        ),
+    )
 
     global_new_q = np.zeros((num_layers, 0, num_qo_heads, head_dim), dtype)
     global_new_k = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
     global_new_v = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
 
     q_array = []
-    for seq_id, append_length in batch:
+    for i, (seq_id, append_length) in enumerate(batch):
         new_q = np.random.rand(num_layers, append_length, num_qo_heads, head_dim).astype(dtype)
         new_k = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype)
         new_v = np.random.rand(num_layers, append_length, num_kv_heads, head_dim).astype(dtype)
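
As the reference above shows, each sequence's token tree is described by a parent-pointer list with -1 marking a root, the per-sequence lists are concatenated into one flattened ShapeTuple for fbegin_forward, and a node's depth is one more than its parent's. A worked example using one of the trees exercised by the tests below:

    # Illustrative: a 7-node binary tree; -1 marks the root.
    parent_ptr = [-1, 0, 0, 1, 1, 2, 2]
    depths = []
    for parent in parent_ptr:
        depths.append(0 if parent == -1 else depths[parent] + 1)
    assert depths == [0, 1, 1, 2, 2, 2, 2]
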
@@ -285,7 +333,11 @@ def apply_attention(
                             new_k[l]
                             if rope_mode != RopeMode.NORMAL
                             else f_apply_rotary(
-                                new_k[l], cached_k[seq_id].shape[1], rope_scale, rope_theta
+                                new_k[l],
+                                cached_k[seq_id].shape[1],
+                                rope_scale,
+                                rope_theta,
+                                token_tree_node_depths_list[i],
                             )
                         )
                         for l in range(num_layers)
@@ -323,12 +375,26 @@ def apply_attention(
                     rope_offset,
                     rope_scale,
                     rope_theta,
+                    token_tree_node_depths_list[i],
                 )
             ).transpose(1, 0, 2)
             k_seq = (
                 cached_k[seq_id][layer_id]
                 if rope_mode != RopeMode.INLINE
-                else f_apply_rotary(cached_k[seq_id][layer_id], 0, rope_scale, rope_theta)
+                else f_apply_rotary(
+                    cached_k[seq_id][layer_id],
+                    0,
+                    rope_scale,
+                    rope_theta,
+                    (
+                        (
+                            list(range(rope_offset))
+                            + [depth + rope_offset for depth in token_tree_node_depths_list[i]]
+                        )
+                        if token_tree_node_depths_list[i] is not None
+                        else None
+                    ),
+                )
             ).transpose(1, 2, 0)
             v_seq = cached_v[seq_id][layer_id].transpose(1, 0, 2)
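
Under NORMAL and INLINE RoPE modes, a tree token is rotated at position rope_offset + depth rather than at a strictly increasing position, where rope_offset is the number of tokens already in the sequence. A small sketch of the position list built above (values are illustrative):

    # Positions used when rotating K for a sequence whose latest append is the
    # 7-node tree [-1, 0, 0, 1, 1, 2, 2] on top of 10 existing tokens.
    rope_offset = 10
    depths = [0, 1, 1, 2, 2, 2, 2]
    positions = list(range(rope_offset)) + [rope_offset + d for d in depths]
    # -> [0, 1, ..., 9, 10, 11, 11, 12, 12, 12, 12]
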
 
@@ -336,11 +402,23 @@ def apply_attention(
             v_seq = np.repeat(v_seq, num_qo_heads // num_kv_heads, axis=0)
             softmax_input = (q_seq.astype("float32") @ k_seq.astype("float32")) / np.sqrt(head_dim)
             softmax_shape = softmax_input.shape
+            assert softmax_shape[-2] == append_length
             length_diff = softmax_shape[-1] - softmax_shape[-2]
             assert length_diff >= 0
             mask = np.tril(
                 np.full_like(softmax_input, np.finfo("float32").max), k=length_diff
             ) + np.triu(np.full_like(softmax_input, np.finfo("float32").min), k=length_diff + 1)
+            if token_tree_parent_ptr_list is not None:
+                tree_mask = np.full(
+                    (append_length, append_length), np.finfo("float32").min, dtype="float32"
+                )
+                for node, parent in enumerate(token_tree_parent_ptr_list[i]):
+                    if parent != -1:
+                        tree_mask[node] = tree_mask[parent]
+                    tree_mask[node, node] = np.finfo("float32").max
+                tree_mask = np.broadcast_to(tree_mask, (num_qo_heads, *tree_mask.shape))
+                mask[:, :, length_diff:] = tree_mask
+
             softmax_input = np.minimum(softmax_input, mask)
 
             results = np.expand_dims(
@@ -359,6 +437,32 @@ def apply_attention(
             sum_length += append_length
     fend_forward(kv_cache)
 
+    if accepted_leaf_indices is not None:
+        fcommit_accepted_token_tree_nodes(kv_cache, ShapeTuple(accepted_leaf_indices))
+        for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
+            zip(accepted_leaf_indices, batch)
+        ):
+            tree_path = []
+            node = accepted_leaf_idx
+            while node != -1:
+                tree_path.append(node)
+                node = token_tree_parent_ptr_list[i][node]
+            offset = cached_k[seq_id].shape[1] - append_length
+            length_to_pop = append_length - len(tree_path)
+            assert 0 <= length_to_pop < append_length
+            for dst_pos, src_pos in enumerate(reversed(tree_path)):
+                if dst_pos == src_pos:
+                    continue
+                cached_k[seq_id][:, offset + dst_pos, ...] = cached_k[seq_id][
+                    :, offset + src_pos, ...
+                ]
+                cached_v[seq_id][:, offset + dst_pos, ...] = cached_v[seq_id][
+                    :, offset + src_pos, ...
+                ]
+            if length_to_pop > 0:
+                cached_k[seq_id] = cached_k[seq_id][:, :-length_to_pop, ...]
+                cached_v[seq_id] = cached_v[seq_id][:, :-length_to_pop, ...]
+
     for seq_id, _ in batch:
         if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id:
             sliding_window_size = sliding_window_sizes[seq_id]
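
Committing an accepted leaf keeps only the nodes on the path from the root of the newly appended tree to that leaf; the reference above compacts those K/V entries to the front of the appended region and pops the rest. A worked sketch using one of the trees and accepted leaves from the tests below:

    # Illustrative: tree [-1, 0, 0, 1, 1, 2, 2] with accepted leaf index 6.
    parent_ptr = [-1, 0, 0, 1, 1, 2, 2]
    path, node = [], 6
    while node != -1:
        path.append(node)
        node = parent_ptr[node]
    # Visited leaf-to-root: [6, 2, 0]; the cache keeps these 3 tokens in
    # root-to-leaf order and pops the remaining 4 appended tokens.
    assert list(reversed(path)) == [0, 2, 6]
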
@@ -618,6 +722,64 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
         )
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if support_sliding_window and rope_mode == RopeMode.NORMAL:
+        # Normal RoPE mode under sliding window settings is not supported.
+        return
+    fclear(kv_cache)
+
+    cached_k = {}
+    cached_v = {}
+    # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+    # Tree attention
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 7), (1, 15), (2, 10), (3, 14)],
+        cached_k,
+        cached_v,
+        token_tree_parent_ptr_list=[
+            [-1, 0, 0, 1, 1, 2, 2],  # complete binary tree of height 3
+            [-1, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6],  # complete binary tree of height 4
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
+            [-1, 0, 0, 1, 1, 2, 2, -1, 7, 7, 8, 8, 9, 9],  # two complete binary trees of height 3
+        ],
+        accepted_leaf_indices=[6, 11, 6, 13],
+    )
+    # Do 5 rounds of decode.
+    for _ in range(5):
+        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)
+
+    # Test the cases where all trees are chains.
+    fclear(kv_cache)
+    cached_k = {}
+    cached_v = {}
+    # Prefill 4 sequences
+    apply_attention(kv_cache, rope_mode, [(0, 10), (1, 20), (2, 30), (3, 40)], cached_k, cached_v)
+    # Tree attention
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 7), (1, 15), (2, 10), (3, 14)],
+        cached_k,
+        cached_v,
+        token_tree_parent_ptr_list=[
+            [-1, 0, 1, 2, 3, 4, 5],  # chain of length 7
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13],  # chain of length 15
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8],  # chain of length 10
+            [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],  # chain of length 14
+        ],
+        accepted_leaf_indices=[2, 6, 6, 4],
+    )
+    # Do 5 rounds of decode.
+    for _ in range(5):
+        apply_attention(kv_cache, rope_mode, [(0, 1), (1, 1), (2, 1), (3, 1)], cached_k, cached_v)
+
+
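
Putting the pieces together, a tree-decoding round with the KV cache follows the pattern exercised above (the handles are the PackedFuncs fetched in set_global_func; seq_id and the tree values are illustrative, and the per-layer attention call is elided):

    # Illustrative round: one sequence proposes a 7-node token tree.
    parent_ptr = [-1, 0, 0, 1, 1, 2, 2]
    fbegin_forward(
        kv_cache,
        ShapeTuple([seq_id]),
        ShapeTuple([len(parent_ptr)]),  # append length per sequence
        ShapeTuple(parent_ptr),         # flattened parent pointers (optional argument)
    )
    # ... run per-layer attention with fattention_with_fuse_qkv here ...
    fend_forward(kv_cache)
    # After verification, keep only the accepted branch (leaf index 6 here);
    # the rejected draft tokens are rolled back from the cache.
    fcommit_accepted_token_tree_nodes(kv_cache, ShapeTuple([6]))
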
 def kv_cache_transpose_append(head_dim, dtype):
     # undefined vars used
     @T.prim_func(check_well_formed=False)
@@ -1843,6 +2005,336 @@ def _attention_prefill_ragged(
     return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
+def _tree_mask(row, col, mask_ptr, offset, stride, kv_len):
+    return tir.all(col < kv_len, mask_ptr[offset + row * stride + col] == 1)
+
+
+def _attention_prefill_with_tree_mask(
+    h_kv, h_q, d, dtype, target: Target
+):  # pylint: disable=unused-argument
+    # pylint: disable=invalid-name,line-too-long
+    NUM_BLKS = 16
+    LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8)  # 8 bytes
+    group_size = h_q // h_kv
+    sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+
+    bdx = 32
+    num_warps = 4
+    tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
+    L_per_cta = tile_x // group_size
+
+    # Otherwise we would exceed maxComputeWorkgroupStorageSize
+    if (
+        str(target.kind) == "webgpu"
+        and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
+    ):
+        tile_z = 8
+        num_warps = 2
+
+    # fmt: off
+    @T.prim_func
+    def batch_tree_attn(  # pylint: disable=too-many-branches
+        var_q: T.handle, # [total_len, h_q, d]
+        var_q_indptr: T.handle, # [batch_size + 1]
+        var_k: T.handle, # [total_len, h_kv, d]
+        var_v: T.handle, # [total_len, h_kv, d]
+        var_kv_indptr: T.handle, # [batch_size + 1], kv_indptr should be the same as q_indptr in this case
+        var_q_rope_position: T.handle, # [total_q_len]
+        var_mn_indptr: T.handle, # [batch_size + 1]
+        var_mask: T.handle, # [mn_indptr[batch_size]]
+        var_output: T.handle, # [total_len, h_q, d]
+        var_lse: T.handle, # [total_len, h_q]
+        rotary_mode: T.int32,
+        rope_scale: T.float32,
+        rope_theta: T.float32,
+        attn_score_scaling_factor: T.float32,
+        batch_size: T.int32,
+    ):
+        qo_len = T.int32(is_size_var=True)
+        kv_len = T.int32(is_size_var=True)
+        q_indptr_elem_offset = T.int32(is_size_var=True)
+        kv_indptr_elem_offset = T.int32(is_size_var=True)
+        q_rope_position_elem_offset = T.int32(is_size_var=True)
+        mn_indptr_elem_offset = T.int32(is_size_var=True)
+        mask_elem_offset = T.int32(is_size_var=True)
+        tree_size = T.int32(is_size_var=True)
+
+        q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
+        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset)
+        q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset)
+        mn_indptr = T.match_buffer(var_mn_indptr, (batch_size + 1,), "int32", elem_offset=mn_indptr_elem_offset)
+        mask = T.match_buffer(var_mask, (tree_size,), "int32", elem_offset=mask_elem_offset)
+        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
+
+        # kernel code
+        for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"):
+            for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
+                for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
+                    for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
+                        with T.block("attn"):
+                            bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
+                            T.reads()
+                            T.writes()
+                            tile_id = _var("int32")
+                            batch_idx = _var("int32")
+                            batch_tiles = _var("int32")
+                            batch_rows = _var("int32")
+                            iterator = _var("int32")
+                            kv_chunk_len = _var("int32")
+
+                            Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared")
+                            K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
+                            V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
+                            S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared")
+
+                            S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local")
+                            O_local = T.alloc_buffer((tile_x, d), "float32", scope="local")
+
+                            m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
+                            m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
+                            d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
+
+                            m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
+                            m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
+                            d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
+
+                            ## get tile_no, batch_idx, batch_tiles, batch_rows
+                            tile_id[0] = bx
+                            batch_idx[0] = 0
+                            batch_rows[0] = (q_indptr[1] - q_indptr[0]) * group_size
+                            batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x)
+                            while T.tvm_thread_invariant(batch_idx[0] < batch_size):
+                                # advance to next tile
+                                while tile_id[0] >= batch_tiles[0] and batch_idx[0] < batch_size:
+                                    tile_id[0] -= batch_tiles[0]
+                                    batch_idx[0] += 1
+                                    if batch_idx[0] < batch_size:
+                                        b_idx: T.int32 = batch_idx[0]
+                                        batch_rows[0] = (q_indptr[b_idx + 1] - q_indptr[b_idx]) * group_size
+                                        batch_tiles[0] = T.ceildiv(batch_rows[0], tile_x)
+
+                                if T.tvm_thread_invariant(batch_idx[0] < batch_size):
+                                    b_idx: T.int32 = batch_idx[0]
+                                    L_start: T.int32 = q_indptr[b_idx] + tile_id[0] * L_per_cta
+                                    H_qo_start: T.int32 = by * group_size
+
+                                    kv_chunk_len[0] = kv_indptr[b_idx + 1] - kv_indptr[b_idx]
+                                    T.tvm_storage_sync("shared")
+
+                                    # init states
+                                    for i in T.serial(T.ceildiv(tile_x, bdx * 
num_warps)):
+                                        row: T.int32 = i * bdx * num_warps + 
ty * bdx + tx
+                                        if row < tile_x:
+                                            m_smem[row] = -5e4
+                                            d_smem[row] = 1.0
+
+                                    for li, lj in T.grid(tile_x, tile_y):
+                                        with T.block("O_init"):
+                                            i, j = T.axis.remap("SS", [li, lj])
+                                            O_local[i, j] = 0.0
+                                    T.tvm_storage_sync("shared")
+
+                                    # Load Q from gmem to smem
+                                    for li, lj in T.grid(tile_x, tile_y):
+                                        with T.block("Q_load"):
+                                            i, j = T.axis.remap("SS", [li, lj])
+                                            T.reads()
+                                            T.writes()
+                                            cur_L = L_start + i // group_size
+                                            cur_H_qo = H_qo_start + i % 
group_size
+                                            if cur_L < q_indptr[b_idx + 1]:
+                                                Q_smem[i, j] = T.if_then_else(
+                                                    rotary_mode == 1,
+                                                    _rope(q, 
q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype),
+                                                    q[cur_L, cur_H_qo, j]
+                                                )
+                                            else:
+                                                Q_smem[i, j] = 0.0
+                                    T.tvm_storage_sync("shared")
+
+                                    for iterator in 
T.serial(T.ceildiv(kv_chunk_len[0], tile_z)):
+                                        L_kv_start: T.int32 = iterator * tile_z
+                                        L_kv_base: T.int32 = kv_indptr[b_idx]
+                                        for lz, ly in T.grid(tile_z, tile_y):
+                                            with T.block("KV_load"):
+                                                i, j = T.axis.remap("SS", [lz, 
ly])
+                                                T.reads()
+                                                T.writes()
+                                                cur_L = L_kv_base + L_kv_start 
+ i
+                                                if L_kv_start + i < 
kv_chunk_len[0]:
+                                                    K_smem[i, j] = 
T.if_then_else(
+                                                        rotary_mode == 1,
+                                                        _rope(k, 
q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, by, j), dtype),
+                                                        k[cur_L, by, j]
+                                                    )
+                                                    V_smem[i, j] = v[cur_L, 
by, j]
+                                                else:
+                                                    K_smem[i, j] = 0.0
+                                                    V_smem[i, j] = 0.0
+                                        T.tvm_storage_sync("shared")
+
+                                        # Compute S
+                                        with T.block():
+                                            for li, lj, lk in T.grid(tile_x, 
tile_z, tile_y):
+                                                with T.block("S_gemm"):
+                                                    i, j, k = 
T.axis.remap("SSR", [li, lj, lk])
+                                                    with T.init():
+                                                        S_local[i, j] = 0.0
+                                                    S_local[i, j] += 
T.cast(Q_smem[i, k], "float32") * T.cast(K_smem[j, k], "float32") * 
attn_score_scaling_factor * sm_scale
+                                        T.tvm_storage_sync("shared")
+                                        for li, lj in T.grid(tile_x, tile_z):
+                                            with T.block("S_store"):
+                                                i, j = T.axis.remap("SS", [li, 
lj])
+                                                S_smem[i, j] = S_local[i, j]
+                                        T.tvm_storage_sync("shared")
+
+                                        # Update S, m, d
+                                        for i in T.serial(T.ceildiv(tile_x, 
bdx * num_warps)):
+                                            row: T.int32 = i * bdx * num_warps 
+ ty * bdx + tx
+                                            if row < tile_x:
+                                                with T.block("update1"):
+                                                    m_prev[i] = m_smem[row]
+                                                    m_new[i] = m_smem[row]
+                                                    # mask out of kv_chunk_len 
S
+                                                    for j in T.serial(tile_z):
+                                                        if 
_tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
+                                                                col=L_kv_start 
+ j,
+                                                                mask_ptr=mask,
+                                                                
offset=mn_indptr[b_idx],
+                                                                
stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                                                                
kv_len=kv_chunk_len[0]):
+                                                            m_new[i] = 
T.max(m_new[i], S_smem[row, j])
+                                                    d_new[i] = d_smem[row] * 
T.exp2(m_prev[i] - m_new[i])
+
+                                        for i in T.serial(T.ceildiv(tile_x, 
bdx * num_warps)):
+                                            row: T.int32 = i * bdx * num_warps 
+ ty * bdx + tx
+                                            with T.block("update"):
+                                                for j in T.serial(tile_z):
+                                                    # this is to avoid sync 
inside condition branch
+                                                    if row < tile_x:
+                                                        if 
_tree_mask(row=tile_id[0] * L_per_cta + row // group_size,
+                                                                col=L_kv_start 
+ j,
+                                                                mask_ptr=mask,
+                                                                
offset=mn_indptr[b_idx],
+                                                                
stride=q_indptr[b_idx + 1] - q_indptr[b_idx],
+                                                                
kv_len=kv_chunk_len[0]):
+                                                            S_smem[row, j] = 
T.exp2(S_smem[row, j] - m_new[i])
+                                                        else:
+                                                            S_smem[row, j] = 
T.exp2(-5e4 - m_new[i])
+
+                                        for i in T.serial(T.ceildiv(tile_x, 
bdx * num_warps)):
+                                            row: T.int32 = i * bdx * num_warps 
+ ty * bdx + tx
+                                            if row < tile_x:
+                                                with T.block("update"):
+                                                    for j in T.serial(tile_z):
+                                                        d_new[i] += 
S_smem[row, j]
+                                                    m_smem[row] = m_new[i]
+                                                    d_smem[row] = d_new[i]
+                                                    m_prev_smem[row] = 
m_prev[i]
+                                        T.tvm_storage_sync("shared")
+
+                                        # Update O
+                                        with T.block():
+                                            for li, lj, lk in T.grid(tile_x, 
tile_y, tile_z):
+                                                with T.block("O_gemm"):
+                                                    i, j, k = 
T.axis.remap("SSR", [li, lj, lk])
+                                                    with T.init():
+                                                        O_local[i, j] *= 
T.exp2(m_prev_smem[i] - m_smem[i])
+                                                    O_local[i, j] += S_smem[i, 
k] * T.cast(V_smem[k, j], "float32")
+
+                                    # Store O from smem to gmem
+                                    for li, lj in T.grid(tile_x, tile_y):
+                                        with T.block("O_store"):
+                                            i, j = T.axis.remap("SS", [li, lj])
+                                            if L_start + i // group_size < 
q_indptr[b_idx + 1]:
+                                                output[L_start + i // 
group_size, H_qo_start + i % group_size, j] = O_local[i, j] / d_smem[i]
+
+                                    # Store LSE to gmem
+                                    for li in T.grid(tile_x):
+                                        with T.block("lse_store"):
+                                            i = T.axis.remap("S", [li])
+                                            if L_start + i // group_size < 
q_indptr[b_idx + 1]:
+                                                lse[L_start + i // group_size, 
H_qo_start + i % group_size] = m_smem[i] + T.log2(d_smem[i])
+
+                                    # move to next tile
+                                    tile_id[0] += NUM_BLKS
+    # fmt: on
+    # pylint: enable=line-too-long,invalid-name,too-many-branches
+    sch = tir.Schedule(batch_tree_attn)
+
+    def get_tile_size(x, y, t):
+        cnt = (x * y) // t
+        assert (x * y) % t == 0
+        tile_y = (int)(math.ceil(math.sqrt(cnt)))
+        while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
+            tile_y += 1
+        assert tile_y <= cnt
+        tile_x = cnt // tile_y
+        return tile_x, tile_y
+
+    def apply_to_qkv_load(sch: tir.Schedule, block):
+        loop_x, loop_y = sch.get_loops(block)[-2:]
+        loop = sch.fuse(loop_x, loop_y)
+        _, ty, tx, vec = sch.split(
+            loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True
+        )
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+        sch.vectorize(vec)
+
+    def apply_to_so_ewise(sch: tir.Schedule, block, tile):
+        loop_x, loop_y = sch.get_loops(block)[-2:]
+        xo, xi = sch.split(loop_x, factors=[None, tile[0]])
+        yo, yi = sch.split(loop_y, factors=[None, tile[1]])
+        sch.reorder(xo, yo, xi, yi)
+        t = sch.fuse(xo, yo)
+        ty, tx = sch.split(t, factors=[None, bdx])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+
+    def apply_to_gemm(  # pylint: disable=unused-argument
+        sch: tir.Schedule, block, tile, read_0, read_1, r_len=8, k_major=False
+    ):
+        loop_x, loop_y, loop_z = sch.get_loops(block)[-3:]
+        xo, xi = sch.split(loop_x, factors=[None, tile[0]])
+        yo, yi = sch.split(loop_y, factors=[None, tile[1]])
+        sch.reorder(xo, yo, xi, yi)
+        t = sch.fuse(xo, yo)
+        ty, tx = sch.split(t, factors=[None, bdx])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+
+        ko, ki = sch.split(loop_z, factors=[None, r_len])
+        if k_major:
+            sch.reorder(ko, xi, yi, ki)
+        else:
+            sch.reorder(ko, ki, xi, yi)
+        sch.decompose_reduction(block, ty)
+
+    def apply_to_md(sch, block):
+        loop = sch.get_loops(block)[-1]
+        _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+
+    tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps)
+    tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps)
+    apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)
+    apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False)
+    apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s)
+    apply_to_so_ewise(sch, sch.get_block("O_init"), tile_o)
+    apply_to_so_ewise(sch, sch.get_block("O_store"), tile_o)
+    apply_to_qkv_load(sch, sch.get_block("Q_load"))
+    apply_to_qkv_load(sch, sch.get_block("KV_load"))
+
+    apply_to_md(sch, sch.get_block("lse_store"))
+    return sch.mod["main"].with_attr("tir.is_scheduled", 1)
+
+
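
The _tree_mask helper above indexes a flattened per-sequence 0/1 mask: for sequence b, entry (row, col) is read from mask[mn_indptr[b] + row * stride + col] with stride equal to that sequence's query length, and a 1 means the query token may attend to that KV token. A NumPy sketch of an ancestor-style mask consistent with the reference mask built in apply_attention (whether the runtime flattens exactly this matrix is an assumption here; only the indexing convention comes from the helper above):

    import numpy as np

    def tree_attn_mask(parent_ptr):
        # Query token i may attend to KV token j iff j is i itself or an ancestor
        # of i in the token tree; returned row-major, as _tree_mask indexes it.
        n = len(parent_ptr)
        mask = np.zeros((n, n), dtype="int32")
        for i, parent in enumerate(parent_ptr):
            if parent != -1:
                mask[i] = mask[parent]
            mask[i, i] = 1
        return mask

    # For [-1, 0, 0]: node 1 attends to {0, 1}, node 2 attends to {0, 2}.
    assert tree_attn_mask([-1, 0, 0]).tolist() == [[1, 0, 0], [1, 1, 0], [1, 0, 1]]
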
 def _merge_state_inplace(
     num_heads, head_dim, v_dtype, target: Target
 ):  # pylint: disable=unused-argument
@@ -1960,6 +2452,56 @@ def _copy_single_page(num_heads, page_size, head_dim, dtype, target: Target):
     return copy_single_page
 
 
+def _compact_kv_copy(num_heads, head_dim, dtype, target: Target):
+    tx = 256 if str(target.kind) == "webgpu" else 1024
+
+    @T.prim_func
+    def compact_kv_copy(
+        var_pages: T.handle,
+        var_copy_length_indptr: T.handle,
+        var_copy_src_dst_pos: T.handle,
+        batch_size: T.int32,
+    ):
+        T.func_attr({"tir.is_scheduled": 1})
+        num_pages = T.int32()
+        total_copy_length = T.int32()
+        copy_length_indptr_elem_offset = T.int32()
+        copy_src_dst_pos_elem_offset = T.int32()
+        pages = T.match_buffer(var_pages, (num_pages, 2, num_heads, 16, head_dim), dtype)
+        copy_length_indptr = T.match_buffer(
+            var_copy_length_indptr,
+            (batch_size + 1,),
+            "int32",
+            elem_offset=copy_length_indptr_elem_offset,
+        )
+        copy_src_dst_pos = T.match_buffer(
+            var_copy_src_dst_pos,
+            (2, total_copy_length),
+            "int32",
+            elem_offset=copy_src_dst_pos_elem_offset,
+        )
+
+        for bhd_o in T.thread_binding(
+            (batch_size * num_heads * head_dim + tx - 1) // tx, thread="blockIdx.x"
+        ):
+            for bhd_i in T.thread_binding(tx, thread="threadIdx.x"):
+                b: T.int32 = (bhd_o * tx + bhd_i) // (num_heads * head_dim)
+                h: T.int32 = (bhd_o * tx + bhd_i) // head_dim % num_heads
+                d: T.int32 = (bhd_o * tx + bhd_i) % head_dim
+                if (bhd_o * tx + bhd_i) < batch_size * num_heads * head_dim:
+                    for i in T.serial(copy_length_indptr[b + 1] - copy_length_indptr[b]):
+                        src_pos: T.int32 = copy_src_dst_pos[0, copy_length_indptr[b] + i]
+                        dst_pos: T.int32 = copy_src_dst_pos[1, copy_length_indptr[b] + i]
+                        pages[dst_pos // 16, 0, h, dst_pos % 16, d] = pages[
+                            src_pos // 16, 0, h, src_pos % 16, d
+                        ]
+                        pages[dst_pos // 16, 1, h, dst_pos % 16, d] = pages[
+                            src_pos // 16, 1, h, src_pos % 16, d
+                        ]
+
+    return compact_kv_copy
+
+
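
For reference, compact_kv_copy above gathers K and V entries within the paged layout: for each sequence b it copies the token slots listed in copy_src_dst_pos (global positions, 16 tokens per page) from their source slot to their destination slot. A NumPy sketch of the same semantics (array names and the usage values are illustrative):

    import numpy as np

    def compact_kv_copy_ref(pages, copy_length_indptr, copy_src_dst_pos, batch_size):
        # pages: (num_pages, 2, num_heads, 16, head_dim); axis 1 holds K (0) and V (1).
        for b in range(batch_size):
            begin, end = copy_length_indptr[b], copy_length_indptr[b + 1]
            for i in range(begin, end):
                src, dst = copy_src_dst_pos[0, i], copy_src_dst_pos[1, i]
                pages[dst // 16, :, :, dst % 16, :] = pages[src // 16, :, :, src % 16, :]

    # Tiny usage: move the K/V of token slot 5 into slot 2 for a single sequence.
    pages = np.random.rand(1, 2, 4, 16, 8).astype("float16")
    compact_kv_copy_ref(pages, [0, 1], np.array([[5], [2]]), batch_size=1)
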
 if __name__ == "__main__":
     HEAD_DIMS = [64, 128]
     DTYPES = ["float16", "float32"]
@@ -1976,3 +2518,4 @@ if __name__ == "__main__":
         test_paged_attention_kv_cache_fork_sequence(cache_and_config)
         test_paged_attention_kv_cache_popn(cache_and_config)
         test_paged_attention_kv_cache_sliding_window(cache_and_config)
+        test_paged_attention_kv_cache_tree_attn(cache_and_config)

