This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 462eeb72b8 [WebLLM] Replace int64s with int32s in WebGPU kernels 
(#18361)
462eeb72b8 is described below

commit 462eeb72b88e1a16412502b7db9f224b8b267590
Author: akaashrp <[email protected]>
AuthorDate: Wed Oct 22 08:43:11 2025 +0530

    [WebLLM] Replace int64s with int32s in WebGPU kernels (#18361)
    
    This PR replaces int64s with int32s in the argsort and
    parallel_sampling_from_prob kernels when the target is WebGPU,
    since WGSL does not currently support the i64 type.
---
 python/tvm/relax/backend/gpu_generic/sampling.py   |  9 +++-
 python/tvm/topi/gpu/sort.py                        | 49 ++++++++++++++++------
 .../python/relax/test_backend_dispatch_sampling.py |  2 +-
 3 files changed, 45 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/backend/gpu_generic/sampling.py 
b/python/tvm/relax/backend/gpu_generic/sampling.py
index 2634a07427..9a0d01ef23 100644
--- a/python/tvm/relax/backend/gpu_generic/sampling.py
+++ b/python/tvm/relax/backend/gpu_generic/sampling.py
@@ -19,6 +19,7 @@
 
 import math
 from typing import Callable, Optional
+import tvm
 from tvm.script import tir as T
 from tvm.tir import PrimFunc
 
@@ -69,6 +70,9 @@ def gpu_multinomial_from_uniform(
         The generated function
     """
 
+    target = tvm.target.Target.current()
+    target_dtype = "int32" if "webgpu" in str(target) else "int64"
+
     TX = T.int64(tx_len)  # threadIdx.x
     TY = T.int64(ty_len)  # threadIdx.y
 
@@ -282,7 +286,8 @@ def gpu_multinomial_from_uniform(
                     # at least one iteration
                     while T.tvm_thread_invariant(
                         (step_iter[()] == 0 or aggregate[()] < u - eps)
-                        and T.Cast("int64", step_iter[()]) < 
T.ceildiv(vocab_size, block_elem)
+                        and T.Cast(target_dtype, step_iter[()])
+                        < T.Cast(target_dtype, T.ceildiv(vocab_size, 
block_elem))
                     ):
                         single_batch_sampling(
                             prob,
@@ -290,7 +295,7 @@ def gpu_multinomial_from_uniform(
                             vocab_size,
                             ty,
                             tx,
-                            T.Cast("int64", step_iter[()]),
+                            T.Cast(target_dtype, step_iter[()]),
                             0.0,
                             aggregate,
                             u,
diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py
index eb48da0a02..807b23a956 100644
--- a/python/tvm/topi/gpu/sort.py
+++ b/python/tvm/topi/gpu/sort.py
@@ -219,11 +219,22 @@ def _sort_common(
     upper_lim = ceil_log2(size)
 
     def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, 
diag, step_count):
-        first = ib.allocate("int64", (1,), name="first", scope="local")
-        mid = ib.allocate("int64", (1,), name="mid", scope="local")
-        last = ib.allocate("int64", (1,), name="last", scope="local")
-        first[0] = tvm.te.max(0, diag - bCount)
-        last[0] = tvm.te.min(diag, aCount)
+        target = tvm.target.Target.current()
+        is_webgpu = "webgpu" in str(target)
+        target_dtype = "int32" if is_webgpu else "int64"
+
+        first = ib.allocate(target_dtype, (1,), name="first", scope="local")
+        mid = ib.allocate(target_dtype, (1,), name="mid", scope="local")
+        last = ib.allocate(target_dtype, (1,), name="last", scope="local")
+        max_val = tvm.te.max(0, diag - bCount)
+        min_val = tvm.te.min(diag, aCount)
+        if is_webgpu:
+            first[0] = cast(max_val, target_dtype)
+            last[0] = cast(min_val, target_dtype)
+        else:
+            first[0] = max_val
+            last[0] = min_val
+
         with ib.while_loop(first[0] < last[0]):
             mid = (first[0] + last[0]) >> 1
             a = source[base_idx + (aStart + mid)]
@@ -250,10 +261,20 @@ def _sort_common(
         first,
         last,
     ):
-        i = ib.allocate("int64", (1,), name="i", scope="local")
-        j = ib.allocate("int64", (1,), name="j", scope="local")
-        i[0] = aStart + first
-        j[0] = bStart + diag - last
+        target = tvm.target.Target.current()
+        is_webgpu = "webgpu" in str(target)
+        target_dtype = "int32" if is_webgpu else "int64"
+        i = ib.allocate(target_dtype, (1,), name="i", scope="local")
+        j = ib.allocate(target_dtype, (1,), name="j", scope="local")
+        i_val = aStart + first
+        j_val = bStart + diag - last
+        if is_webgpu:
+            i[0] = cast(i_val, target_dtype)
+            j[0] = cast(j_val, target_dtype)
+        else:
+            i[0] = i_val
+            j[0] = j_val
+
         with ib.for_range(0, tvm.te.min(aCount + bCount - diag, step_count)) 
as count:
             i_idx = base_idx + i[0]
             j_idx = base_idx + j[0]
@@ -287,7 +308,9 @@ def _sort_common(
                 with ib.else_scope():
                     assign_j()
 
-    with ib.for_range(0, cast(upper_lim - lower_lim, "int64"), dtype="int64") 
as l2_width:
+    target = tvm.target.Target.current()
+    target_dtype = "int32" if "webgpu" in str(target) else "int64"
+    with ib.for_range(0, cast(upper_lim - lower_lim, target_dtype), 
dtype=target_dtype) as l2_width:
         width = 2 << (l2_width + lower_lim)
         # Define and launch the cuda kernel
         with ib.new_scope():
@@ -359,8 +382,10 @@ def _sort_common(
             def mergesort(source, dest, source_idx, dest_idx, size, width, 
even):
                 # calculate the start, mid, and end points of this section
                 start = width * bz
-                middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), 
size), "int64")
-                end = cast(tvm.te.min(start + width, size), "int64")
+                target = tvm.target.Target.current()
+                target_dtype = "int32" if "webgpu" in str(target) else "int64"
+                middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), 
size), target_dtype)
+                end = cast(tvm.te.min(start + width, size), target_dtype)
                 with ib.if_scope(start < size):
                     with ib.if_scope(nbx == 1):
                         ## merge the start->middle and middle->end arrays
diff --git a/tests/python/relax/test_backend_dispatch_sampling.py 
b/tests/python/relax/test_backend_dispatch_sampling.py
index de31efc3fa..fb36f87775 100644
--- a/tests/python/relax/test_backend_dispatch_sampling.py
+++ b/tests/python/relax/test_backend_dispatch_sampling.py
@@ -103,7 +103,7 @@ def test_dispatch_multinomial_from_uniform_gpu():
                         u: T.float32 = uniform_samples[bx, 0]
                         aggregate[()] = T.Cast("float32", 0)
                         step_iter[()] = 0
-                        while T.tvm_thread_invariant((step_iter[()] == 0 or 
aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", 
step_iter[()]) < (vocab_size + T.int64(512) - T.int64(1)) // T.int64(512)):
+                        while T.tvm_thread_invariant((step_iter[()] == 0 or 
aggregate[()] < u - T.float32(9.9999999999999995e-07)) and T.Cast("int64", 
step_iter[()]) < T.Cast("int64", (vocab_size + T.int64(512) - T.int64(1)) // 
T.int64(512))):
                             with T.block(""):
                                 T.reads(step_iter[()], prob[row_idx, 
T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + tx * 
T.int64(4):T.Cast("int64", step_iter[()]) * T.int64(512) + ty * T.int64(128) + 
tx * T.int64(4) + T.int64(4)], aggregate[()])
                                 T.writes(sample_id_local[()], aggregate[()])

Reply via email to