This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch builder-update
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/builder-update by this push:
new bfdf325a4c fix
bfdf325a4c is described below
commit bfdf325a4cfd89e5cb5d7e5c84a9b51e7f2e8375
Author: tqchen <[email protected]>
AuthorDate: Thu Feb 5 15:33:47 2026 -0500
fix
---
python/tvm/relax/transform/legalize_ops/grad.py | 49 ++++++++++++++-----------
python/tvm/script/ir_builder/tir/utils.py | 12 ++++--
python/tvm/topi/gpu/sort.py | 3 +-
3 files changed, 37 insertions(+), 27 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/grad.py b/python/tvm/relax/transform/legalize_ops/grad.py
index 50a97ab181..016350fa41 100644
--- a/python/tvm/relax/transform/legalize_ops/grad.py
+++ b/python/tvm/relax/transform/legalize_ops/grad.py
@@ -21,6 +21,7 @@ import logging
from tvm import te, tir, topi
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import tir as T
+from tvm.script.ir_builder.tir.utils import buffer_proxy
from ...block_builder import BlockBuilder
from ...expr import Call, Expr
from .common import register_legalize
@@ -162,17 +163,26 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
def te_take_backward(output_grad, x, indices):
def gen_ir(output_grad_ptr, x_ptr, indices_ptr, out_ptr):
# pylint: disable=invalid-name
- with IRBuilder() as ib:
- fused_shape = 1
- for i in x_ptr.shape:
- fused_shape *= i
+ # Use buffer_proxy for flat indexing on multi-dimensional buffers
+ out = buffer_proxy(out_ptr)
+ grad = buffer_proxy(output_grad_ptr)
+ idx = buffer_proxy(indices_ptr)
+
+ fused_shape = 1
+ for i in x_ptr.shape:
+ fused_shape *= i
+ # Build init loop (zero-fill output buffer)
+ with IRBuilder() as ib:
with T.serial(fused_shape) as i:
-                T.buffer_store(out_ptr, tir.const(0, dtype=x_ptr.dtype), [i])
+ out[i] = tir.const(0, dtype=x_ptr.dtype)
+ init_stmt = ib.get()
-            assert len(indices_ptr.shape) == 1 # indices in take must be 1-dim Tensor
- indices_len = indices_ptr.shape[0]
+            assert len(indices_ptr.shape) == 1 # indices in take must be 1-dim Tensor
+ indices_len = indices_ptr.shape[0]
+ # Build accumulation loop
+ with IRBuilder() as ib:
if axis is not None:
fused_output_grad_shape_pre = 1
fused_output_grad_shape_nxt = 1
@@ -184,34 +194,29 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
x_axis_len = x_ptr.shape[axis]
- with T.parallel(
+ with T.serial(
fused_output_grad_shape_pre * fused_output_grad_shape_nxt
) as fused:
i = fused // fused_output_grad_shape_nxt
j = fused % fused_output_grad_shape_nxt
- with T.serial(indices_len) as l:
+ with T.serial(indices_len) as loop_l:
out_idx = (
i * fused_output_grad_shape_nxt * x_axis_len
- + indices_ptr[l] * fused_output_grad_shape_nxt
+ + idx[loop_l] * fused_output_grad_shape_nxt
+ j
)
grad_idx = (
i * fused_output_grad_shape_nxt * indices_len
- + l * fused_output_grad_shape_nxt
+ + loop_l * fused_output_grad_shape_nxt
+ j
)
- T.buffer_store(
-                            out_ptr, out_ptr[out_idx] + output_grad_ptr[grad_idx], [out_idx]
- )
+ out[out_idx] = out[out_idx] + grad[grad_idx]
else:
- with T.serial(indices_len) as l:
- T.buffer_store(
- out_ptr,
- out_ptr[indices_ptr[l]] + output_grad_ptr[l],
- [indices_ptr[l]],
- )
-
- return ib.get()
+ with T.serial(indices_len) as loop_l:
+ out[idx[loop_l]] = out[idx[loop_l]] + grad[loop_l]
+ accum_stmt = ib.get()
+
+ return tir.SeqStmt([init_stmt, accum_stmt])
shape = x.shape
out_buf = tir.decl_buffer(shape, x.dtype, "out_buf")
diff --git a/python/tvm/script/ir_builder/tir/utils.py b/python/tvm/script/ir_builder/tir/utils.py
index 37aaf98ee7..fcbf34c61a 100644
--- a/python/tvm/script/ir_builder/tir/utils.py
+++ b/python/tvm/script/ir_builder/tir/utils.py
@@ -17,7 +17,7 @@
"""Utility helpers for TIR IRBuilder."""
import contextlib
-from typing import List, Union
+from typing import List
from tvm import tir
from tvm.tir import Buffer
@@ -167,9 +167,13 @@ def _unravel_index(index, shape):
The multi-dimensional indices.
"""
indices = []
- for dim in reversed(shape):
- indices.append(index % dim)
- index = index // dim
+ for i, dim in enumerate(reversed(shape)):
+ if i == len(shape) - 1:
+ # Outermost dimension: use remaining quotient directly (no modulo)
+ indices.append(index)
+ else:
+ indices.append(index % dim)
+ index = index // dim
return list(reversed(indices))
diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py
index ed6b4d3087..b4bd1a4438 100644
--- a/python/tvm/topi/gpu/sort.py
+++ b/python/tvm/topi/gpu/sort.py
@@ -420,7 +420,8 @@ def _sort_common(
get_merge_begin(
dest, base_idx, aCount, bCount, aStart, bStart, diag, first, last
)
- serial_merge(
+                    # Intentionally swap source/dest for reverse direction merge
+ serial_merge( # pylint: disable=arguments-out-of-order
dest,
source,
dest_idx,