This is an automated email from the ASF dual-hosted git repository.
ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 462eeb72b8 [WebLLM] Replace int64s with int32s in WebGPU kernels
(#18361)
462eeb72b8 is described below
commit 462eeb72b88e1a16412502b7db9f224b8b267590
Author: akaashrp <[email protected]>
AuthorDate: Wed Oct 22 08:43:11 2025 +0530
[WebLLM] Replace int64s with int32s in WebGPU kernels (#18361)
This PR replaces int64s with int32s in the argsort and
parallel_sampling_from_prob
kernels when the target is WebGPU (since WGSL does not currently support
i64)
---
python/tvm/relax/backend/gpu_generic/sampling.py | 9 +++-
python/tvm/topi/gpu/sort.py | 49 ++++++++++++++++------
.../python/relax/test_backend_dispatch_sampling.py | 2 +-
3 files changed, 45 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relax/backend/gpu_generic/sampling.py
b/python/tvm/relax/backend/gpu_generic/sampling.py
index 2634a07427..9a0d01ef23 100644
--- a/python/tvm/relax/backend/gpu_generic/sampling.py
+++ b/python/tvm/relax/backend/gpu_generic/sampling.py
@@ -19,6 +19,7 @@
import math
from typing import Callable, Optional
+import tvm
from tvm.script import tir as T
from tvm.tir import PrimFunc
@@ -69,6 +70,9 @@ def gpu_multinomial_from_uniform(
The generated function
"""
+ target = tvm.target.Target.current()
+ target_dtype = "int32" if "webgpu" in str(target) else "int64"
+
TX = T.int64(tx_len) # threadIdx.x
TY = T.int64(ty_len) # threadIdx.y
@@ -282,7 +286,8 @@ def gpu_multinomial_from_uniform(
# at least one iteration
while T.tvm_thread_invariant(
(step_iter[()] == 0 or aggregate[()] < u - eps)
-                and T.Cast("int64", step_iter[()]) < T.ceildiv(vocab_size, block_elem)
+                and T.Cast(target_dtype, step_iter[()])
+                < T.Cast(target_dtype, T.ceildiv(vocab_size, block_elem))
):
single_batch_sampling(
prob,
@@ -290,7 +295,7 @@ def gpu_multinomial_from_uniform(
vocab_size,
ty,
tx,
- T.Cast("int64", step_iter[()]),
+ T.Cast(target_dtype, step_iter[()]),
0.0,
aggregate,
u,
diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py
index eb48da0a02..807b23a956 100644
--- a/python/tvm/topi/gpu/sort.py
+++ b/python/tvm/topi/gpu/sort.py
@@ -219,11 +219,22 @@ def _sort_common(
upper_lim = ceil_log2(size)
def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart,
diag, step_count):
- first = ib.allocate("int64", (1,), name="first", scope="local")
- mid = ib.allocate("int64", (1,), name="mid", scope="local")
- last = ib.allocate("int64", (1,), name="last", scope="local")
- first[0] = tvm.te.max(0, diag - bCount)
- last[0] = tvm.te.min(diag, aCount)
+ target = tvm.target.Target.current()
+ is_webgpu = "webgpu" in str(target)
+ target_dtype = "int32" if is_webgpu else "int64"
+
+ first = ib.allocate(target_dtype, (1,), name="first", scope="local")
+ mid = ib.allocate(target_dtype, (1,), name="mid", scope="local")
+ last = ib.allocate(target_dtype, (1,), name="last", scope="local")
+ max_val = tvm.te.max(0, diag - bCount)
+ min_val = tvm.te.min(diag, aCount)
+ if is_webgpu:
+ first[0] = cast(max_val, target_dtype)
+ last[0] = cast(min_val, target_dtype)
+ else:
+ first[0] = max_val
+ last[0] = min_val
+
with ib.while_loop(first[0] < last[0]):
mid = (first[0] + last[0]) >> 1
a = source[base_idx + (aStart + mid)]
@@ -250,10 +261,20 @@ def _sort_common(
first,
last,
):
- i = ib.allocate("int64", (1,), name="i", scope="local")
- j = ib.allocate("int64", (1,), name="j", scope="local")
- i[0] = aStart + first
- j[0] = bStart + diag - last
+ target = tvm.target.Target.current()
+ is_webgpu = "webgpu" in str(target)
+ target_dtype = "int32" if is_webgpu else "int64"
+ i = ib.allocate(target_dtype, (1,), name="i", scope="local")
+ j = ib.allocate(target_dtype, (1,), name="j", scope="local")
+ i_val = aStart + first
+ j_val = bStart + diag - last
+ if is_webgpu:
+ i[0] = cast(i_val, target_dtype)
+ j[0] = cast(j_val, target_dtype)
+ else:
+ i[0] = i_val
+ j[0] = j_val
+
with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count))
as count:
i_idx = base_idx + i[0]
j_idx = base_idx + j[0]
@@ -287,7 +308,9 @@ def _sort_common(
with ib.else_scope():
assign_j()
-        with ib.for_range(0, cast(upper_lim - lower_lim, "int64"), dtype="int64") as l2_width:
+        target = tvm.target.Target.current()
+        target_dtype = "int32" if "webgpu" in str(target) else "int64"
+        with ib.for_range(0, cast(upper_lim - lower_lim, target_dtype), dtype=target_dtype) as l2_width:
width = 2 << (l2_width + lower_lim)
# Define and launch the cuda kernel
with ib.new_scope():
@@ -359,8 +382,10 @@ def _sort_common(
def mergesort(source, dest, source_idx, dest_idx, size, width,
even):
# calculate the start, mid, and end points of this section
start = width * bz
-            middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
-            end = cast(tvm.te.min(start + width, size), "int64")
+            target = tvm.target.Target.current()
+            target_dtype = "int32" if "webgpu" in str(target) else "int64"
+            middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), target_dtype)
+            end = cast(tvm.te.min(start + width, size), target_dtype)
with ib.if_scope(start < size):
with ib.if_scope(nbx == 1):
## merge the start->middle and middle->end arrays
diff --git a/tests/python/relax/test_backend_dispatch_sampling.py
b/tests/python/relax/test_backend_dispatch_sampling.py
index de31efc3fa..fb36f87775 100644
--- a/tests/python/relax/test_backend_dispatch_sampling.py
+++ b/tests/python/relax/test_backend_dispatch_sampling.py
@@ -103,7 +103,7 @@ def test_dispatch_multinomial_from_uniform_gpu():
u: T.float32 = uniform_samples[bx, 0]
aggregate[()] = T.Cast("float32", 0)
step_iter[()] = 0
-                while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)):
+                while T.tvm_thread_invariant((step_iter[()] == 0 or aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", step_iter[()]) < T.Cast("int64", (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512))):
with T.block(""):
T.reads(step_iter[()], prob[row_idx,
T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx *
T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) +
tx * T.int64(4) + T.int64(4)], aggregate[()])
T.writes(sample_id_local[()], aggregate[()])