This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e9ddd47a80 [Unity] Add an API to create multiple kv caches with single 
allocation (#15064)
e9ddd47a80 is described below

commit e9ddd47a80e53b77bb99120e90e958efbceeab46
Author: Lite Ye <[email protected]>
AuthorDate: Sat Jun 10 10:54:38 2023 -0400

    [Unity] Add an API to create multiple kv caches with single allocation 
(#15064)
---
 src/runtime/relax_vm/lm_support.cc         | 45 ++++++++++++++++++++++++++++++
 src/runtime/relax_vm/memory_manager.cc     |  7 +++++
 tests/python/relax/test_runtime_builtin.py | 28 +++++++++++++++++--
 3 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/src/runtime/relax_vm/lm_support.cc 
b/src/runtime/relax_vm/lm_support.cc
index cfc596d476..9b14161e67 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -40,6 +40,7 @@
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/memory.h>
 #include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/relax_vm/memory_manager.h>
 #include <tvm/runtime/relax_vm/vm.h>
 
 #include <cmath>
@@ -167,12 +168,56 @@ class AttentionKVCache : public ObjectRef {
 
 TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
 
+/*!
+ * \brief Create multiple kv caches with the same shape, from a single memory 
allocation.
+ * \param init_data Supplies the dtype and device of the caches. Its contents are 
appended to each cache only if init_fill_count is
+ *        greater than 0.
+ * \param reserve_shape The shape of each cache.
+ * \param init_fill_count The fill count assigned to each cache after appending
+ *        init_data; values <= 0 leave every cache empty (fill count 0).
+ * \param num_caches Number of caches to create.
+ */
+Array<AttentionKVCache> CreateMultipleKVCaches(NDArray init_data, ShapeTuple 
reserve_shape,
+                                               int init_fill_count, int 
num_caches) {
+  DLDataType dtype = init_data->dtype;
+
+  int64_t cache_size = (dtype.bits * dtype.lanes + 7) / 8;  // bytes per element, rounded up
+  for (const auto dim : reserve_shape) {
+    cache_size *= dim;
+  }
+
+  // Pad each cache so that every cache starts on a kAllocAlignment boundary
+  using tvm::runtime::kAllocAlignment;
+  int64_t padding = (kAllocAlignment - cache_size % kAllocAlignment) % 
kAllocAlignment;
+  int64_t cache_offset = cache_size + padding;  // byte stride between consecutive caches
+
+  Storage storage =
+      Storage(MemoryManager::GetOrCreateAllocator(init_data->device, 
AllocatorType::kNaive)
+                  ->Alloc(cache_offset * num_caches, kAllocAlignment, dtype));
+
+  Array<AttentionKVCache> result;
+  for (int i = 0; i < num_caches; ++i) {
+    auto c = make_object<AttentionKVCacheObj>();
+    c->data = storage->AllocNDArray(i * cache_offset, reserve_shape, dtype);  // cache i starts at offset i * cache_offset
+    c->fill_count = 0;
+    if (init_fill_count > 0) {
+      c->Append(init_data);
+      c->fill_count = init_fill_count;  // NOTE(review): overwrites fill_count after Append - confirm intended
+    }
+    result.push_back(AttentionKVCache(c));
+  }
+  return result;
+}
+
 //-------------------------------------------------
 //  Register runtime functions
 //-------------------------------------------------
 TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create")
     .set_body_typed(AttentionKVCache::Create);
 
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create_multiple")
+    .set_body_typed(CreateMultipleKVCaches);
+
 AttentionKVCache AttentionKVCacheUpdate(AttentionKVCache cache, NDArray value) 
{
   cache->Update(value);
   return cache;
diff --git a/src/runtime/relax_vm/memory_manager.cc 
b/src/runtime/relax_vm/memory_manager.cc
index 339045f515..7eedad2e56 100644
--- a/src/runtime/relax_vm/memory_manager.cc
+++ b/src/runtime/relax_vm/memory_manager.cc
@@ -28,6 +28,7 @@
 
 #include "naive_allocator.h"
 #include "pooled_allocator.h"
+#include "tvm/runtime/memory.h"
 
 namespace tvm {
 namespace runtime {
@@ -58,6 +59,12 @@ void StorageObj::Deleter(Object* obj) {
   delete ptr;
 }
 
+Storage::Storage(Buffer buffer) {  // wraps the given buffer in a newly created StorageObj
+  auto n = make_object<StorageObj>();
+  n->buffer = std::move(buffer);
+  data_ = std::move(n);
+}
+
 inline void VerifyDataType(DLDataType dtype) {
   ICHECK_GE(dtype.lanes, 1);
   if (dtype.code == kDLFloat) {
diff --git a/tests/python/relax/test_runtime_builtin.py 
b/tests/python/relax/test_runtime_builtin.py
index d25841a71f..682e9d712d 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -158,9 +158,9 @@ def test_attention_kv_cache():
     fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
 
     cache = fcreate(tvm.nd.empty((1, 2), dtype="int32"), 
tvm.runtime.ShapeTuple([2, 2]), 0)
-    num_steps = 0
+    num_steps = 2
     for i in range(num_steps):
-        cache = fappend(cache, tvm.nd.array(i * np.ones((1, 
2).astype("int32"))))
+        cache = fappend(cache, tvm.nd.array(i * np.ones((1, 
2)).astype("int32")))
 
     res = fview(cache, tvm.runtime.ShapeTuple((num_steps, 2))).numpy()
     for i in range(num_steps):
@@ -168,6 +168,30 @@ def test_attention_kv_cache():
         assert res[i][1] == i
 
 
+def test_attention_kv_cache_create_multiple():
+    fcreate = 
tvm.get_global_func("vm.builtin.attention_kv_cache_create_multiple")
+    fappend = tvm.get_global_func("vm.builtin.attention_kv_cache_append")
+    fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
+
+    num_caches = 4
+    cache_group = fcreate(
+        tvm.nd.empty((1, 2), dtype="int32"), tvm.runtime.ShapeTuple([7, 2]), 
0, num_caches
+    )
+
+    num_steps = 7
+    for i in range(num_steps):
+        for cache_index in range(num_caches):
+            fappend(
+                cache_group[cache_index],
+                tvm.nd.array(i * cache_index * np.ones((1, 
2)).astype("int32")),
+            )
+            res = fview(cache_group[cache_index], tvm.runtime.ShapeTuple((i + 
1, 2))).numpy()
+            # Verify that previously appended rows (0..i-1) aren't corrupted; row i itself is not checked
+            for j in range(i):
+                assert res[j][0] == j * cache_index
+                assert res[j][1] == j * cache_index
+
+
 def test_ndarray_cache():
     fload = tvm.get_global_func("vm.builtin.ndarray_cache.load")
     fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache")

Reply via email to