This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e9ddd47a80 [Unity] Add an API to create multiple kv caches with single
allocation (#15064)
e9ddd47a80 is described below
commit e9ddd47a80e53b77bb99120e90e958efbceeab46
Author: Lite Ye <[email protected]>
AuthorDate: Sat Jun 10 10:54:38 2023 -0400
[Unity] Add an API to create multiple kv caches with single allocation
(#15064)
---
src/runtime/relax_vm/lm_support.cc | 45 ++++++++++++++++++++++++++++++
src/runtime/relax_vm/memory_manager.cc | 7 +++++
tests/python/relax/test_runtime_builtin.py | 28 +++++++++++++++++--
3 files changed, 78 insertions(+), 2 deletions(-)
diff --git a/src/runtime/relax_vm/lm_support.cc
b/src/runtime/relax_vm/lm_support.cc
index cfc596d476..9b14161e67 100644
--- a/src/runtime/relax_vm/lm_support.cc
+++ b/src/runtime/relax_vm/lm_support.cc
@@ -40,6 +40,7 @@
#include <tvm/runtime/logging.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/relax_vm/memory_manager.h>
#include <tvm/runtime/relax_vm/vm.h>
#include <cmath>
@@ -167,12 +168,56 @@ class AttentionKVCache : public ObjectRef {
TVM_REGISTER_OBJECT_TYPE(AttentionKVCacheObj);
+/*!
+ * \brief Create multiple kv caches with the same shape from a single memory
allocation.
+ * \param init_data The initial data to put into each cache. Ignored unless
init_fill_count is
+ * positive.
+ * \param reserve_shape The shape of cache.
+ * \param init_fill_count The number of rows to mark as initially filled in
+ * the cache.
+ * \param num_caches Number of caches to create.
+ */
+Array<AttentionKVCache> CreateMultipleKVCaches(NDArray init_data, ShapeTuple
reserve_shape,
+ int init_fill_count, int
num_caches) {
+ DLDataType dtype = init_data->dtype;
+
+ int64_t cache_size = (dtype.bits * dtype.lanes + 7) / 8;
+ for (const auto dim : reserve_shape) {
+ cache_size *= dim;
+ }
+
+ // Add padding to make each cache align to kAllocAlignment
+ using tvm::runtime::kAllocAlignment;
+ int64_t padding = (kAllocAlignment - cache_size % kAllocAlignment) %
kAllocAlignment;
+ int64_t cache_offset = cache_size + padding;
+
+ Storage storage =
+ Storage(MemoryManager::GetOrCreateAllocator(init_data->device,
AllocatorType::kNaive)
+ ->Alloc(cache_offset * num_caches, kAllocAlignment, dtype));
+
+ Array<AttentionKVCache> result;
+ for (int i = 0; i < num_caches; ++i) {
+ auto c = make_object<AttentionKVCacheObj>();
+ c->data = storage->AllocNDArray(i * cache_offset, reserve_shape, dtype);
+ c->fill_count = 0;
+ if (init_fill_count > 0) {
+ c->Append(init_data);
+ c->fill_count = init_fill_count;
+ }
+ result.push_back(AttentionKVCache(c));
+ }
+ return result;
+}
+
//-------------------------------------------------
// Register runtime functions
//-------------------------------------------------
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create")
.set_body_typed(AttentionKVCache::Create);
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_create_multiple")
+ .set_body_typed(CreateMultipleKVCaches);
+
AttentionKVCache AttentionKVCacheUpdate(AttentionKVCache cache, NDArray value)
{
cache->Update(value);
return cache;
diff --git a/src/runtime/relax_vm/memory_manager.cc
b/src/runtime/relax_vm/memory_manager.cc
index 339045f515..7eedad2e56 100644
--- a/src/runtime/relax_vm/memory_manager.cc
+++ b/src/runtime/relax_vm/memory_manager.cc
@@ -28,6 +28,7 @@
#include "naive_allocator.h"
#include "pooled_allocator.h"
+#include "tvm/runtime/memory.h"
namespace tvm {
namespace runtime {
@@ -58,6 +59,12 @@ void StorageObj::Deleter(Object* obj) {
delete ptr;
}
+Storage::Storage(Buffer buffer) {
+ auto n = make_object<StorageObj>();
+ n->buffer = std::move(buffer);
+ data_ = std::move(n);
+}
+
inline void VerifyDataType(DLDataType dtype) {
ICHECK_GE(dtype.lanes, 1);
if (dtype.code == kDLFloat) {
diff --git a/tests/python/relax/test_runtime_builtin.py
b/tests/python/relax/test_runtime_builtin.py
index d25841a71f..682e9d712d 100644
--- a/tests/python/relax/test_runtime_builtin.py
+++ b/tests/python/relax/test_runtime_builtin.py
@@ -158,9 +158,9 @@ def test_attention_kv_cache():
fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
cache = fcreate(tvm.nd.empty((1, 2), dtype="int32"),
tvm.runtime.ShapeTuple([2, 2]), 0)
- num_steps = 0
+ num_steps = 2
for i in range(num_steps):
- cache = fappend(cache, tvm.nd.array(i * np.ones((1,
2).astype("int32"))))
+ cache = fappend(cache, tvm.nd.array(i * np.ones((1,
2)).astype("int32")))
res = fview(cache, tvm.runtime.ShapeTuple((num_steps, 2))).numpy()
for i in range(num_steps):
@@ -168,6 +168,30 @@ def test_attention_kv_cache():
assert res[i][1] == i
+def test_attention_kv_cache_create_multiple():
+ fcreate =
tvm.get_global_func("vm.builtin.attention_kv_cache_create_multiple")
+ fappend = tvm.get_global_func("vm.builtin.attention_kv_cache_append")
+ fview = tvm.get_global_func("vm.builtin.attention_kv_cache_view")
+
+ num_caches = 4
+ cache_group = fcreate(
+ tvm.nd.empty((1, 2), dtype="int32"), tvm.runtime.ShapeTuple([7, 2]),
0, num_caches
+ )
+
+ num_steps = 7
+ for i in range(num_steps):
+ for cache_index in range(num_caches):
+ fappend(
+ cache_group[cache_index],
+ tvm.nd.array(i * cache_index * np.ones((1,
2)).astype("int32")),
+ )
+ res = fview(cache_group[cache_index], tvm.runtime.ShapeTuple((i +
1, 2))).numpy()
+ # Also verify that the old values aren't corrupted
+ for j in range(i):
+ assert res[j][0] == j * cache_index
+ assert res[j][1] == j * cache_index
+
+
def test_ndarray_cache():
fload = tvm.get_global_func("vm.builtin.ndarray_cache.load")
fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache")