This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch builder-update
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/builder-update by this push:
     new bfdf325a4c fix
bfdf325a4c is described below

commit bfdf325a4cfd89e5cb5d7e5c84a9b51e7f2e8375
Author: tqchen <[email protected]>
AuthorDate: Thu Feb 5 15:33:47 2026 -0500

    fix
---
 python/tvm/relax/transform/legalize_ops/grad.py | 49 ++++++++++++++-----------
 python/tvm/script/ir_builder/tir/utils.py       | 12 ++++--
 python/tvm/topi/gpu/sort.py                     |  3 +-
 3 files changed, 37 insertions(+), 27 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/grad.py b/python/tvm/relax/transform/legalize_ops/grad.py
index 50a97ab181..016350fa41 100644
--- a/python/tvm/relax/transform/legalize_ops/grad.py
+++ b/python/tvm/relax/transform/legalize_ops/grad.py
@@ -21,6 +21,7 @@ import logging
 from tvm import te, tir, topi
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import tir as T
+from tvm.script.ir_builder.tir.utils import buffer_proxy
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr
 from .common import register_legalize
@@ -162,17 +163,26 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
     def te_take_backward(output_grad, x, indices):
         def gen_ir(output_grad_ptr, x_ptr, indices_ptr, out_ptr):
             # pylint: disable=invalid-name
-            with IRBuilder() as ib:
-                fused_shape = 1
-                for i in x_ptr.shape:
-                    fused_shape *= i
+            # Use buffer_proxy for flat indexing on multi-dimensional buffers
+            out = buffer_proxy(out_ptr)
+            grad = buffer_proxy(output_grad_ptr)
+            idx = buffer_proxy(indices_ptr)
+
+            fused_shape = 1
+            for i in x_ptr.shape:
+                fused_shape *= i
 
+            # Build init loop (zero-fill output buffer)
+            with IRBuilder() as ib:
                 with T.serial(fused_shape) as i:
-                    T.buffer_store(out_ptr, tir.const(0, dtype=x_ptr.dtype), 
[i])
+                    out[i] = tir.const(0, dtype=x_ptr.dtype)
+                init_stmt = ib.get()
 
-                assert len(indices_ptr.shape) == 1  # indices in take must be 
1-dim Tensor
-                indices_len = indices_ptr.shape[0]
+            assert len(indices_ptr.shape) == 1  # indices in take must be 
1-dim Tensor
+            indices_len = indices_ptr.shape[0]
 
+            # Build accumulation loop
+            with IRBuilder() as ib:
                 if axis is not None:
                     fused_output_grad_shape_pre = 1
                     fused_output_grad_shape_nxt = 1
@@ -184,34 +194,29 @@ def _grad_take_backward(bb: BlockBuilder, call: Call) -> Expr:
 
                     x_axis_len = x_ptr.shape[axis]
 
-                    with T.parallel(
+                    with T.serial(
                         fused_output_grad_shape_pre * 
fused_output_grad_shape_nxt
                     ) as fused:
                         i = fused // fused_output_grad_shape_nxt
                         j = fused % fused_output_grad_shape_nxt
-                        with T.serial(indices_len) as l:
+                        with T.serial(indices_len) as loop_l:
                             out_idx = (
                                 i * fused_output_grad_shape_nxt * x_axis_len
-                                + indices_ptr[l] * fused_output_grad_shape_nxt
+                                + idx[loop_l] * fused_output_grad_shape_nxt
                                 + j
                             )
                             grad_idx = (
                                 i * fused_output_grad_shape_nxt * indices_len
-                                + l * fused_output_grad_shape_nxt
+                                + loop_l * fused_output_grad_shape_nxt
                                 + j
                             )
-                            T.buffer_store(
-                                out_ptr, out_ptr[out_idx] + 
output_grad_ptr[grad_idx], [out_idx]
-                            )
+                            out[out_idx] = out[out_idx] + grad[grad_idx]
                 else:
-                    with T.serial(indices_len) as l:
-                        T.buffer_store(
-                            out_ptr,
-                            out_ptr[indices_ptr[l]] + output_grad_ptr[l],
-                            [indices_ptr[l]],
-                        )
-
-                return ib.get()
+                    with T.serial(indices_len) as loop_l:
+                        out[idx[loop_l]] = out[idx[loop_l]] + grad[loop_l]
+                accum_stmt = ib.get()
+
+            return tir.SeqStmt([init_stmt, accum_stmt])
 
         shape = x.shape
         out_buf = tir.decl_buffer(shape, x.dtype, "out_buf")
diff --git a/python/tvm/script/ir_builder/tir/utils.py b/python/tvm/script/ir_builder/tir/utils.py
index 37aaf98ee7..fcbf34c61a 100644
--- a/python/tvm/script/ir_builder/tir/utils.py
+++ b/python/tvm/script/ir_builder/tir/utils.py
@@ -17,7 +17,7 @@
 """Utility helpers for TIR IRBuilder."""
 
 import contextlib
-from typing import List, Union
+from typing import List
 
 from tvm import tir
 from tvm.tir import Buffer
@@ -167,9 +167,13 @@ def _unravel_index(index, shape):
         The multi-dimensional indices.
     """
     indices = []
-    for dim in reversed(shape):
-        indices.append(index % dim)
-        index = index // dim
+    for i, dim in enumerate(reversed(shape)):
+        if i == len(shape) - 1:
+            # Outermost dimension: use remaining quotient directly (no modulo)
+            indices.append(index)
+        else:
+            indices.append(index % dim)
+            index = index // dim
     return list(reversed(indices))
 
 
diff --git a/python/tvm/topi/gpu/sort.py b/python/tvm/topi/gpu/sort.py
index ed6b4d3087..b4bd1a4438 100644
--- a/python/tvm/topi/gpu/sort.py
+++ b/python/tvm/topi/gpu/sort.py
@@ -420,7 +420,8 @@ def _sort_common(
                     get_merge_begin(
                         dest, base_idx, aCount, bCount, aStart, bStart, diag, 
first, last
                     )
-                    serial_merge(
+                    # Intentionally swap source/dest for reverse direction 
merge
+                    serial_merge(  # pylint: disable=arguments-out-of-order
                         dest,
                         source,
                         dest_idx,

Reply via email to