This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ae2ab58ad6 [KVCache] Fix the reference counter in sequence fork (#16666)
ae2ab58ad6 is described below
commit ae2ab58ad682b963a93adddc7148bbad8154093e
Author: Ruihang Lai <[email protected]>
AuthorDate: Mon Mar 4 08:55:23 2024 -0500
[KVCache] Fix the reference counter in sequence fork (#16666)
This PR fixes a sequence reference counter bug in the KV cache:
when forking a child sequence from an existing parent sequence,
the reference counter of the parent sequence was not increased.
This caused an error when the child sequence was later removed:
the removal path checks the parent's reference counter, finds it
is unexpectedly 0, and fails.
Meanwhile, this PR updates the PagedKVCache tests with several
recent changes, including target-aware tile size selection.
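For concreteness, a minimal sketch of the failing pattern, written
against the helper packed functions the test file below sets up
(fadd_sequence / ffork_sequence / fremove_sequence); the exact names
and signatures are illustrative, not a stable public API:

    # Hypothetical repro sketch, not part of this commit.
    fadd_sequence(kv_cache, 0)      # create parent sequence 0
    ffork_sequence(kv_cache, 0, 1)  # fork child sequence 1 from parent 0
    # Before this fix, the parent block's external_ref_cnt stayed at 0,
    # so removing the child tripped the parent ref-count check:
    fremove_sequence(kv_cache, 1)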
---
src/runtime/relax_vm/paged_kv_cache.cc | 1 +
...runtime_builtin_paged_attention_kv_cache_tir.py | 290 +++++++++++++--------
2 files changed, 177 insertions(+), 114 deletions(-)
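As background on the target-aware part: the tests now derive the
thread-block budget from the build target instead of hardcoding thread
counts. A rough sketch of the rule, mirroring the
get_max_num_threads_per_block helper added in the diff below (the 512
cap is the decode kernel's, used here only for illustration):

    from tvm.target import Target

    def thread_limit(target: Target, cap: int = 512) -> int:
        # Prefer the larger of the two fields when both exist,
        # then clamp to the kernel's conservative cap.
        limit = target.max_num_threads
        per_block = target.attrs.get("max_threads_per_block", None)
        if per_block is not None:
            limit = max(limit, per_block)
        return min(limit, cap)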
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index f848ed2490..6dec511f2f 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -475,6 +475,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
<< "Attention merge-score function not available. ForkSequence is thereby not supported.";
int32_t parent_block_idx = parent_it->second.last_block_idx;
+ ++global_block_pool_[parent_block_idx].external_ref_cnt;
// Create a child block with the parent block pointer.
int32_t child_block_idx = GetFreeBlock();
global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length;
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 365420dd12..34e9d51715 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -25,10 +25,12 @@ import scipy.special
import tvm
import tvm.testing
+from tvm import DataType
from tvm import dlight as dl
from tvm import tir
from tvm.runtime import ShapeTuple
from tvm.script import tir as T
+from tvm.target import Target
reserved_nseq = 32
maximum_total_seq_length = 1024
@@ -88,10 +90,10 @@ def set_global_func(head_dim, dtype):
for tir_func in [
kv_cache_transpose_append(head_dim, dtype),
copy_cache(head_dim, dtype),
- _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype),
- _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype),
- _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype),
- _merge_state_inplace(num_qo_heads, head_dim, dtype),
+ _attention_prefill(num_kv_heads, num_qo_heads, head_dim, dtype, target),
+ _attention_decode(num_kv_heads, num_qo_heads, head_dim, dtype, target),
+ _attention_prefill_ragged(num_kv_heads, num_qo_heads, head_dim, dtype, target),
+ _merge_state_inplace(num_qo_heads, head_dim, dtype, target),
llama_rope_with_position_map(
rope_theta, rope_scale, head_dim, num_qo_heads, num_kv_heads, dtype
),
@@ -410,6 +412,12 @@ def test_paged_attention_kv_cache_fork_sequence(kv_cache_and_rope_mode, fuse_qkv
for batch in operation_seq:
apply_attention(kv_cache, rope_mode, batch, cached_k, cached_v, fuse_qkv)
+ for i in range(9, -1, -1):
+ fremove_sequence(kv_cache, i)
+ cached_k.pop(i)
+ cached_v.pop(i)
+ verify_cached_kv(kv_cache, seq_ids=list(range(i)), expected_k=cached_k, expected_v=cached_v)
+
@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
@@ -517,7 +525,6 @@ def _inplace_rope(
num_kv_heads: int,
dtype: str,
):
- assert head_dim <= 128, "Rotary embedding currently only supports head_dim <= 128"
rotary_dim = head_dim
def _rope(
@@ -714,17 +721,38 @@ def _var(dtype):
return T.alloc_buffer((1,), dtype, scope="local")
-def _attention_prefill(h_kv, h_q, d, dtype):
+def get_max_num_threads_per_block(target: Target):
+ """
+ max(max_num_threads, max_threads_per_block); if the latter does not exist, return max_num_threads.
+ We add this method since some targets have both fields and `max_threads_per_block` is larger.
+ """
+ max_num_threads = target.max_num_threads
+ max_threads_per_block = target.attrs.get("max_threads_per_block", None)
+ if max_threads_per_block is None:
+ return max_num_threads
+ return max(max_num_threads, max_threads_per_block)
+
+
+def _attention_prefill(h_kv, h_q, d, dtype, target: Target): # pylint: disable=unused-argument
# pylint: disable=invalid-name
NUM_BLKS = 16
- LOAD_VEC = 8 // ((tvm.runtime.DataType(dtype).bits + 7) // 8) # 8 bytes
+ LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes
group_size = h_q // h_kv
sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+ bdx = 32
num_warps = 4
- tile_x, tile_y, tile_z = 64 // ((tvm.DataType(dtype).bits + 7) // 8), d, 16
+ tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
L_per_cta = tile_x // group_size
+ # Otherwise we would exceed maxComputeWorkgroupStorageSize
+ if (
+ str(target.kind) == "webgpu"
+ and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
+ ):
+ tile_z = 8
+ num_warps = 2
+
def mask(causal, row, col, kv_len, qo_len):
return T.if_then_else(
causal > 0,
@@ -744,7 +772,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
var_page_values: T.handle, # [nnz_pages]
var_last_page_len: T.handle, # [b]
var_k_rope_pos_offset: T.handle, # [b]
- var_q_rope_position: T.handle, # [total_q_len]
+ var_q_rope_position: T.handle, # [total_len]
var_output: T.handle, # [total_len, h_q, d]
var_lse: T.handle, # [total_len, h_q]
causal: T.int32,
@@ -773,7 +801,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"):
for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
- for ltx in T.thread_binding(32, thread="threadIdx.x"):
+ for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
with T.block("attn"):
bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
T.reads()
@@ -797,9 +825,9 @@ def _attention_prefill(h_kv, h_q, d, dtype):
m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
- m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
- m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
- d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
+ m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
+ m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
+ d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
## get tile_no, batch_idx, batch_tiles, batch_rows
tile_id[0] = bx
@@ -832,8 +860,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
T.tvm_storage_sync("shared")
# init states
- for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
- row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
m_smem[row] = -5e4
d_smem[row] = 1.0
@@ -871,8 +899,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
T.writes()
cur_L = L_kv_start + i
if cur_L < kv_chunk_len[0]:
- page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)]
- page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16)
+ page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)] # type: ignore
+ page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16) # type: ignore
K_smem[i, j] = T.if_then_else(
rotary_mode == 1,
_rope(pages, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (page_no, 0, by, page_offset, j), dtype),
@@ -888,8 +916,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
T.writes()
cur_L = L_kv_start + i
if cur_L < kv_chunk_len[0]:
- page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)]
- page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16)
+ page_no: T.int32(is_size_var=True) = page_values[cur_page_indptr_begin + T.floordiv(cur_L, 16)] # type: ignore
+ page_offset: T.int32(is_size_var=True) = T.floormod(cur_L, 16) # type: ignore
V_smem[i, j] = pages[page_no, 1, by, page_offset, j]
else:
V_smem[i, j] = 0.0
@@ -911,8 +939,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
T.tvm_storage_sync("shared")
# Update S, m, d
- for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
- row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
with T.block("update1"):
m_prev[i] = m_smem[row]
@@ -927,8 +955,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
m_new[i] = T.max(m_new[i], S_smem[row, j])
d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
- for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
- row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
with T.block("update"):
for j in T.serial(tile_z):
# this is to avoid sync inside condition branch
@@ -942,8 +970,8 @@ def _attention_prefill(h_kv, h_q, d, dtype):
else:
S_smem[row, j] = T.exp2(-5e4 - m_new[i])
- for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
- row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
with T.block("update"):
for j in T.serial(tile_z):
@@ -986,7 +1014,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
cnt = (x * y) // t
assert (x * y) % t == 0
tile_y = (int)(math.ceil(math.sqrt(cnt)))
- while cnt % tile_y != 0 and y % tile_y != 0 and tile_y <= cnt:
+ while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
tile_y += 1
assert tile_y <= cnt
tile_x = cnt // tile_y
@@ -996,19 +1024,19 @@ def _attention_prefill(h_kv, h_q, d, dtype):
loop_x, loop_y = sch.get_loops(block)[-2:]
loop = sch.fuse(loop_x, loop_y)
_, ty, tx, vec = sch.split(
- loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True
+ loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True
)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)
- def apply_to_so_ewise(sch: tir.Schedule, block, tile, vec_len=4):
+ def apply_to_so_ewise(sch: tir.Schedule, block, tile):
loop_x, loop_y = sch.get_loops(block)[-2:]
xo, xi = sch.split(loop_x, factors=[None, tile[0]])
yo, yi = sch.split(loop_y, factors=[None, tile[1]])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
- ty, tx = sch.split(t, factors=[num_warps, 32])
+ ty, tx = sch.split(t, factors=[num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
@@ -1020,7 +1048,7 @@ def _attention_prefill(h_kv, h_q, d, dtype):
yo, yi = sch.split(loop_y, factors=[None, tile[1]])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
- ty, tx = sch.split(t, factors=[num_warps, 32])
+ ty, tx = sch.split(t, factors=[num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
@@ -1033,12 +1061,12 @@ def _attention_prefill(h_kv, h_q, d, dtype):
def apply_to_md(sch, block):
loop = sch.get_loops(block)[-1]
- _, ty, tx = sch.split(loop, factors=[None, num_warps, 32])
+ _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
- tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps)
- tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps)
+ tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps)
+ tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps)
apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)
apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False)
apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s)
@@ -1051,18 +1079,30 @@ def _attention_prefill(h_kv, h_q, d, dtype):
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
-def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
+def _attention_decode(
+ num_kv_heads,
+ num_qo_heads,
+ head_dim,
+ qkv_dtype,
+ target: Target, # pylint: disable=unused-argument
+):
# pylint: disable=invalid-name
qkv_dtype_bytes = 2
H_qo = num_qo_heads
H_kv = num_kv_heads
D = head_dim
+ max_num_threads_per_block = get_max_num_threads_per_block(target)
+ thread_limit = min(max_num_threads_per_block, 512)
+
GROUP_SIZE = H_qo // H_kv
- VEC_SIZE = max(8 // qkv_dtype_bytes, D // 32)
+ VEC_SIZE = min(max(8 // qkv_dtype_bytes, D // 32), 4)
bdx = D // VEC_SIZE
bdy = GROUP_SIZE
- threads_per_CTA = max(512, bdx * bdy)
+ while bdx * bdy > thread_limit and bdy > 1:
+ bdy //= 2
+ gdz = GROUP_SIZE // bdy
+ threads_per_CTA = max(thread_limit, bdx * bdy)
bdz = threads_per_CTA // (bdx * bdy)
tile_size_per_bdx = 2 if GROUP_SIZE == 1 else 1
log2e = math.log2(math.exp(1))
@@ -1106,7 +1146,7 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
sm_scale = 1.0 / math.sqrt(float(D)) * log2e
for bx in T.thread_binding(B, thread="blockIdx.x"):
- for by in T.thread_binding(H_kv, thread="blockIdx.y"):
+ for fused_by_bz in T.thread_binding(H_kv * gdz, thread="blockIdx.y"):
for ty in T.thread_binding(bdy, thread="threadIdx.y"):
for tx in T.thread_binding(bdx, thread="threadIdx.x"):
for tz in T.thread_binding(bdz, thread="threadIdx.z"):
@@ -1132,6 +1172,8 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
st_d = T.alloc_buffer((1,), "float32", scope="local")
O_local = T.alloc_buffer((VEC_SIZE,), "float32", scope="local")
+ by: T.int32 = fused_by_bz % H_kv
+ bz: T.int32 = fused_by_bz // H_kv
batch_idx: T.int32 = bx
cur_page_indptr_begin: T.int32 = page_table_indptr[batch_idx]
cur_page_indptr_end: T.int32 = page_table_indptr[batch_idx + 1]
@@ -1152,19 +1194,19 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
for vec in T.vectorized(VEC_SIZE):
Q_local[vec] = T.if_then_else(
rotary_mode == 1,
- _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec), qkv_dtype),
- Q[bx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec]
+ _rope(Q, q_rope_position[batch_idx], head_dim, rope_theta, rope_scale, (bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec), qkv_dtype),
+ Q[bx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec]
)
for iterator in T.serial(T.ceildiv(kv_chunk_len[0], tile_size_per_bdx * bdy * bdz)):
- tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx
- tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx
+ tile_start_s: T.int32(is_size_var=True) = (tz * bdy + ty) * tile_size_per_bdx # type: ignore
+ tile_start_g: T.int32(is_size_var=True) = ((iterator * bdz + tz) * bdy + ty) * tile_size_per_bdx # type: ignore
# load K from global memory to shared memory
for j in T.serial(tile_size_per_bdx):
- row_g: T.int32(is_size_var=True) = tile_start_g + j
+ row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore
if row_g < kv_chunk_len[0]:
- page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)]
- page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16)
+ page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)] # type: ignore
+ page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16) # type: ignore
for vec in T.vectorized(VEC_SIZE):
K_smem[tile_start_s + j, tx * VEC_SIZE + vec] = T.if_then_else(
rotary_mode == 1,
@@ -1177,10 +1219,10 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
T.tvm_storage_sync("shared")
# load V from global memory to shared memory
for j in T.serial(tile_size_per_bdx):
- row_g: T.int32(is_size_var=True) = tile_start_g + j
+ row_g: T.int32(is_size_var=True) = tile_start_g + j # type: ignore
if row_g < kv_chunk_len[0]:
- page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)]
- page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16)
+ page_no: T.int32(is_size_var=True) = page_table_values[cur_page_indptr_begin + T.floordiv(row_g, 16)] # type: ignore
+ page_offset: T.int32(is_size_var=True) = T.floormod(row_g, 16) # type: ignore
for vec in T.vectorized(VEC_SIZE):
V_smem[tile_start_s + j, tx * VEC_SIZE + vec] = pages[page_no, 1, by, page_offset, tx * VEC_SIZE + vec]
else:
@@ -1263,26 +1305,37 @@ def _attention_decode(num_kv_heads, num_qo_heads, head_dim, qkv_dtype):
# store O to global memory
for vec in T.vectorized(VEC_SIZE):
- output[batch_idx, by * GROUP_SIZE + ty, tx * VEC_SIZE + vec] = O_local[vec]
+ output[batch_idx, by * GROUP_SIZE + bz * bdy + ty, tx * VEC_SIZE + vec] = O_local[vec]
# store lse to global memory
- lse[batch_idx, by * GROUP_SIZE + ty] = st_m[0] + T.log2(st_d[0])
+ lse[batch_idx, by * GROUP_SIZE + bz * bdy + ty] = st_m[0] + T.log2(st_d[0])
# fmt: on
# pylint: enable=line-too-long,invalid-name,too-many-arguments,too-many-branches
return batch_decode_paged_kv
-def _attention_prefill_ragged(h_kv, h_q, d, dtype):
- # pylint: disable=invalid-name
+def _attention_prefill_ragged(
+ h_kv, h_q, d, dtype, target: Target
+): # pylint: disable=unused-argument
+ # pylint: disable=invalid-name,line-too-long
NUM_BLKS = 16
- LOAD_VEC = 8 // ((tvm.DataType(dtype).bits + 7) // 8) # 8 bytes
+ LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes
group_size = h_q // h_kv
sm_scale = 1.0 / math.sqrt(float(d)) * math.log2(math.exp(1))
+ bdx = 32
num_warps = 4
- tile_x, tile_y, tile_z = 64 // ((tvm.DataType(dtype).bits + 7) // 8), d, 16
+ tile_x, tile_y, tile_z = 64 // ((DataType(dtype).bits + 7) // 8) // max(d // 128, 1), d, 16
L_per_cta = tile_x // group_size
+ # Otherwise we would exceed maxComputeWorkgroupStorageSize
+ if (
+ str(target.kind) == "webgpu"
+ and ((d + 127) // 128) * ((DataType(dtype).bits + 15) // 16) >= 4
+ ):
+ tile_z = 8
+ num_warps = 2
+
def mask(causal, row, col, kv_len, qo_len):
return T.if_then_else(
causal > 0,
@@ -1292,7 +1345,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
# fmt: off
@T.prim_func
- def batch_prefill_ragged_kv(
+ def batch_prefill_ragged_kv( # pylint: disable=too-many-arguments,too-many-branches
var_q: T.handle, # [total_len, h_q, d]
var_q_indptr: T.handle, # [batch_size + 1]
var_k: T.handle, # [total_len, h_kv, d]
@@ -1306,7 +1359,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
rotary_mode: T.int32,
rope_scale: T.float32,
rope_theta: T.float32,
- attn_score_scaling_factor: T.float32,
+ attn_score_scaling_factor: T.float32
):
batch_size = T.int32(is_size_var=True)
qo_len = T.int32(is_size_var=True)
@@ -1326,7 +1379,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
for lbx in T.thread_binding(NUM_BLKS, thread="blockIdx.x"):
for lby in T.thread_binding(h_kv, thread="blockIdx.y"):
for lty in T.thread_binding(num_warps, thread="threadIdx.y"):
- for ltx in T.thread_binding(32, thread="threadIdx.x"):
+ for ltx in T.thread_binding(bdx, thread="threadIdx.x"):
with T.block("attn"):
bx, by, ty, tx = T.axis.remap("SSSS", [lbx, lby, lty, ltx])
T.reads()
@@ -1350,9 +1403,9 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
d_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
- m_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
- m_prev = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
- d_new = T.alloc_buffer((math.ceil(tile_x / (32 * num_warps)),), "float32", scope="local")
+ m_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
+ m_prev = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
+ d_new = T.alloc_buffer((math.ceil(tile_x / (bdx * num_warps)),), "float32", scope="local")
## get tile_no, batch_idx, batch_tiles, batch_rows
tile_id[0] = bx
@@ -1378,8 +1431,8 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
T.tvm_storage_sync("shared")
# init states
- for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
- row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
m_smem[row] = -5e4
d_smem[row] = 1.0
@@ -1454,8 +1507,8 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
T.tvm_storage_sync("shared")
# Update S, m, d
- for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
- row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
with T.block("update1"):
m_prev[i] = m_smem[row]
@@ -1470,8 +1523,8 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
m_new[i] = T.max(m_new[i], S_smem[row, j])
d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
- for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
- row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
with T.block("update"):
for j in T.serial(tile_z):
# this is to avoid sync inside condition branch
@@ -1485,8 +1538,8 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
else:
S_smem[row, j] = T.exp2(-5e4 - m_new[i])
- for i in T.serial(T.ceildiv(tile_x, 32 * num_warps)):
- row: T.int32 = i * 32 * num_warps + ty * 32 + tx
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
if row < tile_x:
with T.block("update"):
for j in T.serial(tile_z):
@@ -1529,7 +1582,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
cnt = (x * y) // t
assert (x * y) % t == 0
tile_y = (int)(math.ceil(math.sqrt(cnt)))
- while cnt % tile_y != 0 and y % tile_y != 0 and tile_y <= cnt:
+ while (cnt % tile_y != 0 or y % tile_y != 0) and tile_y <= cnt:
tile_y += 1
assert tile_y <= cnt
tile_x = cnt // tile_y
@@ -1539,19 +1592,19 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
loop_x, loop_y = sch.get_loops(block)[-2:]
loop = sch.fuse(loop_x, loop_y)
_, ty, tx, vec = sch.split(
- loop, factors=[None, num_warps, 32, LOAD_VEC], preserve_unit_iters=True
+ loop, factors=[None, num_warps, bdx, LOAD_VEC], preserve_unit_iters=True
)
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.vectorize(vec)
- def apply_to_so_ewise(sch: tir.Schedule, block, tile, vec_len=4):
+ def apply_to_so_ewise(sch: tir.Schedule, block, tile):
loop_x, loop_y = sch.get_loops(block)[-2:]
xo, xi = sch.split(loop_x, factors=[None, tile[0]])
yo, yi = sch.split(loop_y, factors=[None, tile[1]])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
- ty, tx = sch.split(t, factors=[num_warps, 32])
+ ty, tx = sch.split(t, factors=[num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
@@ -1563,7 +1616,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
yo, yi = sch.split(loop_y, factors=[None, tile[1]])
sch.reorder(xo, yo, xi, yi)
t = sch.fuse(xo, yo)
- ty, tx = sch.split(t, factors=[num_warps, 32])
+ ty, tx = sch.split(t, factors=[num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
@@ -1576,12 +1629,12 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
def apply_to_md(sch, block):
loop = sch.get_loops(block)[-1]
- _, ty, tx = sch.split(loop, factors=[None, num_warps, 32])
+ _, ty, tx = sch.split(loop, factors=[None, num_warps, bdx])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
- tile_s = get_tile_size(tile_x, tile_z, 32 * num_warps)
- tile_o = get_tile_size(tile_x, tile_y, 32 * num_warps)
+ tile_s = get_tile_size(tile_x, tile_z, bdx * num_warps)
+ tile_o = get_tile_size(tile_x, tile_y, bdx * num_warps)
apply_to_gemm(sch, sch.get_block("S_gemm"), tile_s, 0, 1, k_major=True)
apply_to_gemm(sch, sch.get_block("O_gemm"), tile_o, 2, 3, k_major=False)
apply_to_so_ewise(sch, sch.get_block("S_store"), tile_s)
@@ -1595,12 +1648,18 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype):
return sch.mod["main"].with_attr("tir.is_scheduled", 1)
-def _merge_state_inplace(num_heads, head_dim, v_dtype):
+def _merge_state_inplace(
+ num_heads, head_dim, v_dtype, target: Target
+): # pylint: disable=unused-argument
# pylint: disable=invalid-name
v_dtype_bytes = 2
- VEC_SIZE = max(8 // v_dtype_bytes, head_dim // 32)
+ VEC_SIZE = min(max(8 // v_dtype_bytes, head_dim // 32), 4)
bdx = head_dim // VEC_SIZE
bdy = num_heads
+ max_num_threads_per_block = get_max_num_threads_per_block(target)
+ while bdx * bdy > max_num_threads_per_block and bdy > 1:
+ bdy //= 2
+ gdy = num_heads // bdy
@T.prim_func
def merge_state_inplace(
@@ -1620,43 +1679,46 @@ def _merge_state_inplace(num_heads, head_dim, v_dtype):
S_other = T.match_buffer(s_other, (N, H), "float32")
for bx in T.thread_binding(N, thread="blockIdx.x"):
- for ty in T.thread_binding(bdy, thread="threadIdx.y"):
- for tx in T.thread_binding(bdx, thread="threadIdx.x"):
- with T.block("merge"):
- s_val = _var("float32")
- s_other_val = _var("float32")
- s_max = _var("float32")
- scale = _var("float32")
- other_scale = _var("float32")
-
- v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local")
- v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local")
-
- s_val[0] = S[bx, ty]
- s_other_val[0] = S_other[bx, ty]
- s_max[0] = T.max(s_val[0], s_other_val[0])
- s_val[0] = T.exp2(s_val[0] - s_max[0])
- s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
- scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
- other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
-
- # load v
- for vec in T.vectorized(VEC_SIZE):
- v_vec[vec] = V[bx, ty, tx * VEC_SIZE + vec]
- # load v_other
- for vec in T.vectorized(VEC_SIZE):
- v_other_vec[vec] = V_other[bx, ty, tx * VEC_SIZE + vec]
-
- # merge
- for vec in T.serial(VEC_SIZE):
- v_vec[vec] = v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0]
-
- # store v
- for vec in T.vectorized(VEC_SIZE):
- V[bx, ty, tx * VEC_SIZE + vec] = v_vec[vec]
-
- # store s
- S[bx, ty] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]
+ for by in T.thread_binding(gdy, thread="blockIdx.y"):
+ for ty in T.thread_binding(bdy, thread="threadIdx.y"):
+ for tx in T.thread_binding(bdx, thread="threadIdx.x"):
+ with T.block("merge"):
+ s_val = _var("float32")
+ s_other_val = _var("float32")
+ s_max = _var("float32")
+ scale = _var("float32")
+ other_scale = _var("float32")
+
+ v_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local")
+ v_other_vec = T.alloc_buffer((VEC_SIZE,), v_dtype, scope="local")
+
+ s_val[0] = S[bx, ty + by * bdy]
+ s_other_val[0] = S_other[bx, ty + by * bdy]
+ s_max[0] = T.max(s_val[0], s_other_val[0])
+ s_val[0] = T.exp2(s_val[0] - s_max[0])
+ s_other_val[0] = T.exp2(s_other_val[0] - s_max[0])
+ scale[0] = s_val[0] / (s_val[0] + s_other_val[0])
+ other_scale[0] = s_other_val[0] / (s_val[0] + s_other_val[0])
+
+ # load v
+ for vec in T.vectorized(VEC_SIZE):
+ v_vec[vec] = V[bx, ty + by * bdy, tx * VEC_SIZE + vec]
+ # load v_other
+ for vec in T.vectorized(VEC_SIZE):
+ v_other_vec[vec] = V_other[bx, ty + by * bdy, tx * VEC_SIZE + vec]
+
+ # merge
+ for vec in T.serial(VEC_SIZE):
+ v_vec[vec] = (
+ v_vec[vec] * scale[0] + v_other_vec[vec] * other_scale[0]
+ )
+
+ # store v
+ for vec in T.vectorized(VEC_SIZE):
+ V[bx, ty + by * bdy, tx * VEC_SIZE + vec] = v_vec[vec]
+
+ # store s
+ S[bx, ty + by * bdy] = T.log2(s_val[0] + s_other_val[0]) + s_max[0]
# pylint: enable=invalid-name
return merge_state_inplace