This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dc7d6873ba [Runtime] PagedKVCache execute data copy on a separate 
stream (#16692)
dc7d6873ba is described below

commit dc7d6873badeabddf98824c807fefe4a1a45194b
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 10 17:28:27 2024 -0400

    [Runtime] PagedKVCache execute data copy on a separate stream (#16692)
    
    This PR enhances PagedKVCache by moving data copies onto a separate
    copy stream. Specifically, for the CUDA and ROCm backends, we create a
    standalone stream for copying the auxiliary data structures from CPU
    to GPU. Furthermore, we move the copy from BeginForward to Attention,
    so it is no longer executed eagerly; instead, it is executed lazily
    when the Attention computation is first needed.
    
    With these changes, the auxiliary data copy (on the copy stream) can
    overlap with the model forward computation that happens before the
    first Attention call. As a result, part of the copy latency is hidden.
    
    This PR also bumps the FlashInfer submodule to a version that
    supports copy streams.
---
 3rdparty/flashinfer                    |   2 +-
 src/runtime/relax_vm/paged_kv_cache.cc | 161 +++++++++++++++++++++------------
 2 files changed, 106 insertions(+), 57 deletions(-)

diff --git a/3rdparty/flashinfer b/3rdparty/flashinfer
index f1f6a0de4e..0d04571b61 160000
--- a/3rdparty/flashinfer
+++ b/3rdparty/flashinfer
@@ -1 +1 @@
-Subproject commit f1f6a0de4e595b777e29cc0dc370c15bd1d736fb
+Subproject commit 0d04571b614c944b5831d080882107a98b9c6e65
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc 
b/src/runtime/relax_vm/paged_kv_cache.cc
index 6dec511f2f..fb22d20fcf 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -242,7 +242,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
   //-------------------------------------------
   /*!
    * \brief A boolean flag indicating if the auxiliary arrays are dirty.
-   * If it is dirty, an explicit "SyncAuxArrayToDevice" should be invoked.
+   * If it is dirty, an explicit "ComputeStreamWaitForCopyStream" should be 
invoked.
    */
   bool dirty_aux_data_device_ = false;
   /*! \brief The batch size of the current round of forwarding. */
@@ -285,6 +285,20 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   NDArray merged_attn_scores_device_;
   std::vector<NDArray> temp_attn_workspace_;
 
+  //-------------------------------------------
+  // Below are the auxiliary data structure on CPU.
+  // We make them class members to avoid repetitive allocation time in 
BeginForward.
+  //-------------------------------------------
+  std::vector<std::vector<int32_t>> qo_indptr_on_depths_host_;
+  std::vector<std::vector<int32_t>> page_indptr_on_depths_host_;
+  std::vector<std::vector<int32_t>> page_indices_on_depths_host_;
+  std::vector<std::vector<int32_t>> last_page_len_on_depths_host_;
+  std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths_host_;
+  std::vector<int32_t> k_ragged_rope_pos_offset_host_;
+  std::vector<int32_t> q_rope_position_map_host_;
+  std::vector<int32_t> append_position_map_host_;
+  std::vector<int32_t> cur_append_lengths_indptr_host_;
+
   //-------------------------------------------
   // For efficient memory management, the actual sizes of the arrays
   // above are over allocated.
@@ -328,6 +342,12 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
   std::vector<bool> use_decode_kernel_;
   /*! \brief Whether the attention request is a decode request, set in 
BeginForwardFunction. */
   bool is_decode_request_;
+  /*! \brief The device this PagedKVCache runs on. */
+  DLDevice device_;
+  /*! \brief The device stream for the default computation operations. */
+  TVMStreamHandle compute_stream_ = nullptr;
+  /*! \brief The device stream for copying auxiliary data structure to GPU. */
+  TVMStreamHandle copy_stream_ = nullptr;
 
  public:
   /*! \brief Constructor. Take the cache configuration and initialize the 
NDArrays. */
@@ -370,7 +390,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
         f_merge_inplace_(std::move(f_merge_inplace)),
         f_split_rotary_(std::move(f_split_rotary)),
         f_rotary_inplace_(std::move(f_rotary_inplace)),
-        f_debug_get_kv_(std::move(f_debug_get_kv)) {
+        f_debug_get_kv_(std::move(f_debug_get_kv)),
+        device_(device) {
     pages_.reserve(num_layers);
     for (int i = 0; i < num_layers; ++i) {
       pages_.push_back(
@@ -417,6 +438,22 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     for (int64_t page_id = num_total_pages - 1; page_id >= 0; --page_id) {
       free_page_ids_.push_back(page_id);
     }
+
+    // The compute stream is the default stream.
+    // If the device is CUDA/ROCm, we create a standalone copy stream, in
+    // purpose to hide the latency of auxiliary stream copy.
+    compute_stream_ = DeviceAPI::Get(device)->GetCurrentStream(device);
+    if (device.device_type == DLDeviceType::kDLCUDA ||
+        device.device_type == DLDeviceType::kDLROCM) {
+      copy_stream_ = DeviceAPI::Get(device)->CreateStream(device);
+    }
+  }
+
+  ~PagedAttentionKVCacheObj() {
+    // Free the copy stream if defined.
+    if (copy_stream_ != nullptr) {
+      DeviceAPI::Get(device_)->FreeStream(device_, copy_stream_);
+    }
   }
 
   /*! \brief Reset the KV cache. */
@@ -522,16 +559,15 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
 
     // - Collect sequence/block/page information for attention.
     std::vector<const Sequence*> sequences;
-    std::vector<int32_t> k_ragged_rope_pos_offset;
     is_decode_request_ = true;
     sequences.reserve(cur_batch_size_);
-    k_ragged_rope_pos_offset.reserve(cur_batch_size_);
+    k_ragged_rope_pos_offset_host_.resize(cur_batch_size_);
     for (int i = 0; i < cur_batch_size_; ++i) {
       auto it = seq_map_.find(seq_ids[i]);
       CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
                                   << "\" cannot be found in KV cache.";
       sequences.push_back(&it->second);
-      k_ragged_rope_pos_offset.push_back(it->second.seq_length);
+      k_ragged_rope_pos_offset_host_[i] = it->second.seq_length;
       it->second.seq_length += append_lengths[i];
       if (append_lengths[i] != 1) {
         is_decode_request_ = false;
@@ -561,18 +597,25 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       }
     }
 
-    std::vector<std::vector<int32_t>> qo_indptr_on_depths;
-    std::vector<std::vector<int32_t>> page_indptr_on_depths;
-    std::vector<std::vector<int32_t>> page_indices_on_depths;
-    std::vector<std::vector<int32_t>> last_page_len_on_depths;
-    std::vector<std::vector<int32_t>> k_rope_pos_offset_on_depths;
+    qo_indptr_on_depths_host_.resize(num_depths_);
+    page_indptr_on_depths_host_.resize(num_depths_);
+    page_indices_on_depths_host_.resize(num_depths_);
+    last_page_len_on_depths_host_.resize(num_depths_);
+    k_rope_pos_offset_on_depths_host_.resize(num_depths_);
 
     for (int d = 0; d < num_depths_; ++d) {
-      std::vector<int32_t> qo_indptr_h{0};
-      std::vector<int32_t> page_indptr_h{0};
-      std::vector<int32_t> page_indices_h;
-      std::vector<int32_t> last_page_len_h;
-      std::vector<int32_t> k_rope_pos_offset_h;
+      std::vector<int32_t>& qo_indptr_h = qo_indptr_on_depths_host_[d];
+      std::vector<int32_t>& page_indptr_h = page_indptr_on_depths_host_[d];
+      std::vector<int32_t>& page_indices_h = page_indices_on_depths_host_[d];
+      std::vector<int32_t>& last_page_len_h = last_page_len_on_depths_host_[d];
+      std::vector<int32_t>& k_rope_pos_offset_h = 
k_rope_pos_offset_on_depths_host_[d];
+      qo_indptr_h.clear();
+      page_indptr_h.clear();
+      page_indices_h.clear();
+      last_page_len_h.clear();
+      k_rope_pos_offset_h.clear();
+      qo_indptr_h.push_back(0);
+      page_indptr_h.push_back(0);
       for (const auto& [block_id, chunk_append_length] : 
chunked_block_ids_arr[d]) {
         qo_indptr_h.push_back(qo_indptr_h.back() + chunk_append_length);
         if (block_id == -1) {
@@ -588,11 +631,6 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
           k_rope_pos_offset_h.push_back(block.start_pos);
         }
       }
-      qo_indptr_on_depths.push_back(qo_indptr_h);
-      page_indptr_on_depths.push_back(page_indptr_h);
-      page_indices_on_depths.push_back(page_indices_h);
-      last_page_len_on_depths.push_back(last_page_len_h);
-      k_rope_pos_offset_on_depths.push_back(k_rope_pos_offset_h);
     }
 
     if (!append_before_attn_) {
@@ -606,28 +644,18 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
 
     // Map each the token position in the input batch to the position
     // in the global KV cache. The mapping is used in when appending k/v 
values.
-    std::vector<int32_t> q_rope_position_map;
-    std::vector<int32_t> append_position_map;
+    q_rope_position_map_host_.clear();
+    append_position_map_host_.clear();
     for (int i = 0; i < cur_batch_size_; ++i) {
       int64_t append_length = append_lengths[i];
       const Block& block = global_block_pool_[sequences[i]->last_block_idx];
       for (int64_t pos = 0; pos < append_length; ++pos) {
         int64_t pos_in_block = block.seq_length - append_length + pos;
-        q_rope_position_map.push_back(sequences[i]->seq_length - append_length 
+ pos);
-        append_position_map.push_back(block.page_ids[pos_in_block / 
page_size_] * page_size_ +
-                                      pos_in_block % page_size_);
+        q_rope_position_map_host_.push_back(sequences[i]->seq_length - 
append_length + pos);
+        append_position_map_host_.push_back(block.page_ids[pos_in_block / 
page_size_] * page_size_ +
+                                            pos_in_block % page_size_);
       }
     }
-
-    // - Sync NDArrays to GPU.
-    SyncAuxArrayToDevice(std::move(qo_indptr_on_depths), 
std::move(page_indptr_on_depths),
-                         std::move(page_indices_on_depths), 
std::move(last_page_len_on_depths),
-                         std::move(k_rope_pos_offset_on_depths),
-                         std::move(k_ragged_rope_pos_offset), 
std::move(q_rope_position_map),
-                         std::move(append_position_map));
-
-    // NOTE(Zihao): This logic is problematic ATM because we need a unique 
split per depth
-    KernelBeginForward();
   }
 
   void EndForward() final {
@@ -635,9 +663,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj 
{
         !f_attention_prefill_ragged_end_forward_.defined()) {
       return;
     }
-    // Mark the dirty flag as true, so that BeginForward is required
-    // to be invoked before the next round of model forward.
-    dirty_aux_data_device_ = true;
     f_attention_prefill_ragged_end_forward_.value()();
     for (int d = 0; d < num_depths_; ++d) {
       f_attention_prefill_end_forward_.value()(d);
@@ -681,10 +706,10 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       total_seq_length += cur_append_lengths_[seq_id];
     }
     CHECK_EQ(total_seq_length, q_data->shape[0]);
+    // Sync the copy stream and the compute stream.
+    ComputeStreamWaitForCopyStream();
     // The auxiliary data structure on device must have been synchronized.
-    CHECK(!dirty_aux_data_device_)
-        << "The auxiliary arrays are not synchronized to device. Please call "
-           "`BeginForward` to synchronize before calling `Attention`.";
+    ICHECK(!dirty_aux_data_device_);
 
     if (rope_mode_ == RoPEMode::kNormal) {
       // Apply rotary embedding to q/k data.
@@ -726,10 +751,10 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       total_seq_length += cur_append_lengths_[seq_id];
     }
     CHECK_EQ(total_seq_length, qkv_data->shape[0]);
+    // Sync the copy stream and the compute stream.
+    ComputeStreamWaitForCopyStream();
     // The auxiliary data structure on device must have been synchronized.
-    CHECK(!dirty_aux_data_device_)
-        << "The auxiliary arrays are not synchronized to device. Please call "
-           "`BeginForward` to synchronize before calling `Attention`.";
+    ICHECK(!dirty_aux_data_device_);
 
     NDArray q_data = temp_attn_q_device_.CreateView({total_seq_length, 
num_qo_heads_, head_dim_},
                                                     qkv_data->dtype);
@@ -965,11 +990,11 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       f_attention_decode_begin_forward_.value()(
           /*depth=*/0, temp_attn_workspace_[1], page_indptr_on_depths_view_[0],
           last_page_len_on_depths_view_[0], num_qo_heads_, num_kv_heads_, 
head_dim_, page_size_,
-          /*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
+          /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
     } else {
       f_attention_prefill_ragged_begin_forward_.value()(
           temp_attn_workspace_[0], cur_append_length_indptr_view_, 
cur_batch_size_, num_qo_heads_,
-          num_kv_heads_);
+          num_kv_heads_, head_dim_, copy_stream_);
       for (int d = 0; d < num_depths_; ++d) {
         if (page_indices_on_depths_view_[d]->shape[0] == 0) {
           continue;
@@ -978,11 +1003,12 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
           f_attention_decode_begin_forward_.value()(
               d, temp_attn_workspace_[d + 1], page_indptr_on_depths_view_[d],
               last_page_len_on_depths_view_[d], num_qo_heads_, num_kv_heads_, 
head_dim_, page_size_,
-              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline);
+              /*rotary_mode=*/rope_mode_ == RoPEMode::kInline, copy_stream_);
         } else {
           f_attention_prefill_begin_forward_.value()(
               /*depth=*/d, temp_attn_workspace_[d + 1], 
qo_indptr_on_depths_view_[d],
-              last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, 
num_kv_heads_);
+              last_page_len_on_depths_view_[d]->shape[0], num_qo_heads_, 
num_kv_heads_, head_dim_,
+              copy_stream_);
         }
       }
     }
@@ -1041,6 +1067,28 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     }
   }
 
+  /*! \brief Synchronize the copy stream and the compute stream. */
+  void ComputeStreamWaitForCopyStream() {
+    if (!dirty_aux_data_device_) {
+      // If the auxiliary data is already synced, return and no need to sync 
again.
+      return;
+    }
+    // - Sync NDArrays to GPU.
+    SyncAuxArrayToDevice(qo_indptr_on_depths_host_, 
page_indptr_on_depths_host_,
+                         page_indices_on_depths_host_, 
last_page_len_on_depths_host_,
+                         k_rope_pos_offset_on_depths_host_, 
k_ragged_rope_pos_offset_host_,
+                         q_rope_position_map_host_, append_position_map_host_);
+    KernelBeginForward();
+    // - Clear the dirty flag.
+    dirty_aux_data_device_ = false;
+    // - If there is no particular copy stream, no action is needed.
+    if (copy_stream_ == nullptr) {
+      return;
+    }
+    // - Sync two streams.
+    DeviceAPI::Get(device_)->SyncStreamFromTo(device_, copy_stream_, 
compute_stream_);
+  }
+
   /*!
    * \brief Synchronize auxiliary arrays to device.
    * \note This method resets the dirty flag to false, and needs to be
@@ -1061,15 +1109,16 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     ICHECK_EQ(last_page_len_on_depths.size(), num_depths_);
     int64_t total_append_length = 0;
     int num_sequences = cur_append_lengths_.size();
-    std::vector<int32_t> cur_append_lengths_indptr{0};
-    for (int i = 0; i < static_cast<int>(cur_append_lengths_.size()); ++i) {
-      cur_append_lengths_indptr.push_back(cur_append_lengths_indptr.back() +
-                                          cur_append_lengths_[i]);
+    cur_append_lengths_indptr_host_.resize(num_sequences + 1);
+    cur_append_lengths_indptr_host_[0] = 0;
+    for (int i = 0; i < num_sequences; ++i) {
+      cur_append_lengths_indptr_host_[i + 1] =
+          cur_append_lengths_indptr_host_[i] + cur_append_lengths_[i];
     }
-    total_append_length = cur_append_lengths_indptr.back();
+    total_append_length = cur_append_lengths_indptr_host_.back();
     ICHECK_EQ(total_append_length, append_position_map.size());
 
-    auto fcopy_from_vec = [](NDArray array, int32_t* vec_data) {
+    auto fcopy_from_vec = [copy_stream = this->copy_stream_](NDArray array, 
int32_t* vec_data) {
       DLTensor copy_dst = *array.operator->();
       DLTensor copy_src;
       copy_src.data = vec_data;
@@ -1079,7 +1128,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
       copy_src.shape = array->shape;
       copy_src.strides = nullptr;
       copy_src.byte_offset = 0;
-      NDArray::CopyFromTo(&copy_src, &copy_dst);
+      NDArray::CopyFromTo(&copy_src, &copy_dst, copy_stream);
     };
 
     // 1. qo_indptr_on_depths
@@ -1126,7 +1175,7 @@ class PagedAttentionKVCacheObj : public 
AttentionKVCacheObj {
     // 6. cur_append_lengths_indptr
     cur_append_length_indptr_view_ =
         cur_append_length_indptr_device_.CreateView({num_sequences + 1}, 
dtype_aux_);
-    fcopy_from_vec(cur_append_length_indptr_view_, 
cur_append_lengths_indptr.data());
+    fcopy_from_vec(cur_append_length_indptr_view_, 
cur_append_lengths_indptr_host_.data());
 
     // 7. k_ragged_rope_pos_offset
     ICHECK_EQ(k_ragged_rope_pos_offset.size(), num_sequences);

Reply via email to