gemini-code-assist[bot] commented on code in PR #18361:
URL: https://github.com/apache/tvm/pull/18361#discussion_r2412367199
##########
python/tvm/relax/backend/gpu_generic/sampling.py:
##########
@@ -282,15 +283,15 @@ def parallel_sampling_from_prob(
# at least one iteration
while T.tvm_thread_invariant(
    (step_iter[()] == 0 or aggregate[()] < u - eps)
-    and T.Cast("int64", step_iter[()]) < T.ceildiv(vocab_size, block_elem)
+    and T.Cast("int32" if "webgpu" in str(tvm.target.Target.current()) else "int64", step_iter[()]) < T.Cast("int32" if "webgpu" in str(tvm.target.Target.current()) else "int64", T.ceildiv(vocab_size, block_elem))
Review Comment:

The expression `"int32" if "webgpu" in str(tvm.target.Target.current()) else
"int64"` is used multiple times in this function (here and on line 294). This
harms readability and maintainability. Consider defining a Python variable for
this string outside of the `@T.prim_func` definition, inside
`gpu_multinomial_from_uniform`, and using that variable here. This would make
the code cleaner and easier to modify.
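
For illustration, a minimal sketch of the suggested hoisting; `idx_dtype` is a hypothetical name for the extracted variable, and the surrounding code is elided:
```python
# Inside gpu_multinomial_from_uniform, before the @T.prim_func definition:
# pick the index dtype once (webgpu has no native 64-bit integers).
idx_dtype = "int32" if "webgpu" in str(tvm.target.Target.current()) else "int64"

# The while condition in the diff above would then read:
#   and T.Cast(idx_dtype, step_iter[()]) < T.Cast(idx_dtype, T.ceildiv(vocab_size, block_elem))
```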
##########
python/tvm/topi/gpu/sort.py:
##########
@@ -359,8 +376,13 @@ def merge(source, dest, source_idx, dest_idx):
    def mergesort(source, dest, source_idx, dest_idx, size, width, even):
        # calculate the start, mid, and end points of this section
        start = width * bz
-        middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
-        end = cast(tvm.te.min(start + width, size), "int64")
+        target = tvm.target.Target.current()
+        if "webgpu" in str(target):
+            middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int32")
+            end = cast(tvm.te.min(start + width, size), "int32")
+        else:
+            middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), "int64")
+            end = cast(tvm.te.min(start + width, size), "int64")
Review Comment:

This conditional logic can be simplified by defining the `dtype` based on
the target and then using it for casting. This avoids repeating the `cast` and
`tvm.te.min` calls.
```python
        target = tvm.target.Target.current()
        dtype = "int32" if "webgpu" in str(target) else "int64"
        middle = cast(tvm.te.min(start + tvm.tir.indexdiv(width, 2), size), dtype)
        end = cast(tvm.te.min(start + width, size), dtype)
```
##########
python/tvm/topi/gpu/sort.py:
##########
@@ -250,10 +258,17 @@ def serial_merge(
        first,
        last,
    ):
-        i = ib.allocate("int64", (1,), name="i", scope="local")
-        j = ib.allocate("int64", (1,), name="j", scope="local")
-        i[0] = aStart + first
-        j[0] = bStart + diag - last
+        target = tvm.target.Target.current()
+        if "webgpu" in str(target):
+            i = ib.allocate("int32", (1,), name="i", scope="local")
+            j = ib.allocate("int32", (1,), name="j", scope="local")
+            i[0] = cast(aStart + first, "int32")
+            j[0] = cast(bStart + diag - last, "int32")
+        else:
+            i = ib.allocate("int64", (1,), name="i", scope="local")
+            j = ib.allocate("int64", (1,), name="j", scope="local")
+            i[0] = aStart + first
+            j[0] = bStart + diag - last
Review Comment:

Similar to the previous comment, this block has duplicated code. You can
refactor it by determining the `dtype` first and then using it.
```suggestion
        target = tvm.target.Target.current()
        is_webgpu = "webgpu" in str(target)
        dtype = "int32" if is_webgpu else "int64"
        i = ib.allocate(dtype, (1,), name="i", scope="local")
        j = ib.allocate(dtype, (1,), name="j", scope="local")
        i_val = aStart + first
        j_val = bStart + diag - last
        if is_webgpu:
            i[0] = cast(i_val, dtype)
            j[0] = cast(j_val, dtype)
        else:
            i[0] = i_val
            j[0] = j_val
```
##########
python/tvm/topi/gpu/sort.py:
##########
@@ -219,11 +219,19 @@ def compare(a, b):
    upper_lim = ceil_log2(size)
    def get_merge_begin(source, base_idx, aCount, bCount, aStart, bStart, diag, step_count):
-        first = ib.allocate("int64", (1,), name="first", scope="local")
-        mid = ib.allocate("int64", (1,), name="mid", scope="local")
-        last = ib.allocate("int64", (1,), name="last", scope="local")
-        first[0] = tvm.te.max(0, diag - bCount)
-        last[0] = tvm.te.min(diag, aCount)
+        target = tvm.target.Target.current()
+        if "webgpu" in str(target):
+            first = ib.allocate("int32", (1,), name="first", scope="local")
+            mid = ib.allocate("int32", (1,), name="mid", scope="local")
+            last = ib.allocate("int32", (1,), name="last", scope="local")
+            first[0] = cast(tvm.te.max(0, diag - bCount), "int32")
+            last[0] = cast(tvm.te.min(diag, aCount), "int32")
+        else:
+            first = ib.allocate("int64", (1,), name="first", scope="local")
+            mid = ib.allocate("int64", (1,), name="mid", scope="local")
+            last = ib.allocate("int64", (1,), name="last", scope="local")
+            first[0] = tvm.te.max(0, diag - bCount)
+            last[0] = tvm.te.min(diag, aCount)
Review Comment:

There's quite a bit of duplicated code between the `if` and `else` branches.
You can refactor this to improve readability and reduce redundancy by defining
the `dtype` first and then using it for allocation and casting. This also makes
the logic clearer.
```suggestion
        target = tvm.target.Target.current()
        is_webgpu = "webgpu" in str(target)
        dtype = "int32" if is_webgpu else "int64"
        first = ib.allocate(dtype, (1,), name="first", scope="local")
        mid = ib.allocate(dtype, (1,), name="mid", scope="local")
        last = ib.allocate(dtype, (1,), name="last", scope="local")
        max_val = tvm.te.max(0, diag - bCount)
        min_val = tvm.te.min(diag, aCount)
        if is_webgpu:
            first[0] = cast(max_val, dtype)
            last[0] = cast(min_val, dtype)
        else:
            first[0] = max_val
            last[0] = min_val
```
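
Since the same target check recurs in each of the three `sort.py` hunks above, it could also be hoisted into a single small helper. A minimal sketch, assuming a module-level function (the name `_index_dtype` is hypothetical, not part of this PR):
```python
import tvm


def _index_dtype() -> str:
    """Index dtype for the merge-sort IR; webgpu lacks native 64-bit integers."""
    target = tvm.target.Target.current()
    return "int32" if target is not None and "webgpu" in str(target) else "int64"
```
Each call site would then use `_index_dtype()` in place of the repeated `"int32" if "webgpu" in str(target) else "int64"` expression.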
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]