This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fde09d2052 [BugFix][Relax] Fix scatter_elements and scatter_nd CUDA compilation (#19497)
fde09d2052 is described below

commit fde09d2052d1b7238d6ac61f0b083baf64d7c098
Author: as4230 <[email protected]>
AuthorDate: Mon May 4 04:30:00 2026 -0400

    [BugFix][Relax] Fix scatter_elements and scatter_nd CUDA compilation (#19497)
    
    `topi.scatter_elements` and `topi.scatter_nd` emit bare `T.parallel`
    loops in their `te.extern` IRBuilder bodies. On CUDA targets these
    loops are never bound to GPU threads, so `VerifyMemory` fails:
    
        RuntimeError: Memory verification failed
        ...
        Did you forget to bind?
    
    CPU (LLVM) is unaffected.
    
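    This fix adds GPU-specific implementations in `topi/gpu/scatter_elements.py`
    and `topi/gpu/scatter_nd.py` whose IRBuilder bodies bind `blockIdx.x` /
    `threadIdx.x` instead of emitting `T.parallel`. During `LegalizeOps`,
    `relax.scatter_elements` and `relax.scatter_nd` dispatch to these kernels
    when `Target.current()` is a GPU target (its keys include "gpu"); all other
    targets keep using the existing `topi` implementations.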
    
    Fixes #19451.
---
 .../tvm/relax/transform/legalize_ops/manipulate.py |  12 +-
 python/tvm/topi/gpu/__init__.py                    |   2 +
 python/tvm/topi/gpu/scatter_elements.py            | 162 +++++++++++++++++++++
 python/tvm/topi/gpu/scatter_nd.py                  | 129 ++++++++++++++++
 .../test_transform_legalize_ops_manipulate.py      |  46 ++++++
 5 files changed, 349 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 2a1d249ef7..fc7ee0d12e 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -235,10 +235,16 @@ def _meshgrid(bb: BlockBuilder, call: Call) -> Expr:
     )


+def _is_gpu_target():
+    target = tvm.target.Target.current(allow_none=True)
+    return target is not None and "gpu" in target.keys
+
+
 @register_legalize("relax.scatter_elements")
 def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
+    te_func = topi.gpu.scatter_elements if _is_gpu_target() else topi.scatter_elements
     return bb.call_te(
-        topi.scatter_elements,
+        te_func,
         call.args[0],
         call.args[1],
         call.args[2],
@@ -250,10 +256,12 @@ def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
 @register_legalize("relax.scatter_nd")
 def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
     # TODO(relax-team): Support native scatter_nd without te extern
+    base_te = topi.gpu.scatter_nd if _is_gpu_target() else topi.scatter_nd
+
     def scatter_nd(data, indices, updates, reduction):
         axes = list(range(len(indices.shape)))
         indices = topi.transpose(indices, axes[-1:] + axes[:-1])
-        return topi.scatter_nd(data, indices, updates, reduction)
+        return base_te(data, indices, updates, reduction)

     return bb.call_te(
         scatter_nd,
diff --git a/python/tvm/topi/gpu/__init__.py b/python/tvm/topi/gpu/__init__.py
index e56a1d7123..69998957f3 100644
--- a/python/tvm/topi/gpu/__init__.py
+++ b/python/tvm/topi/gpu/__init__.py
@@ -20,4 +20,6 @@
 """GPU specific declaration."""

 from .scan import cumsum, cumprod
+from .scatter_elements import scatter_elements
+from .scatter_nd import scatter_nd
 from .sort import *
diff --git a/python/tvm/topi/gpu/scatter_elements.py b/python/tvm/topi/gpu/scatter_elements.py
new file mode 100644
index 0000000000..a7d9421862
--- /dev/null
+++ b/python/tvm/topi/gpu/scatter_elements.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""scatter_elements related operators"""
+
+import tvm
+from tvm import te, tirx
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import tirx as T
+
+from .. import utils
+from ..math import cast
+from ..utils import ceil_div
+
+
+def scatter_elements(data, indices, updates, axis=0, reduction="update"):
+    """GPU implementation of scatter_elements with explicit thread bindings"""
+    if not isinstance(axis, int):
+        axis = utils.get_const_int(axis)
+
+    # Prepare ranges and strides
+    shape = data.shape
+    if axis < 0:
+        axis = len(shape) + axis
+    axis_range = cast(shape[axis], indices.dtype)
+
+    full_range = 1
+    after_axis_range = 1
+    for i, value in enumerate(shape, 0):
+        full_range *= value
+        if i > axis:
+            after_axis_range *= value
+    before_axis_stride = axis_range * after_axis_range
+
+    ind_shape = indices.shape
+    ind_axis_range = ind_shape[axis]
+
+    ind_before_axis_range = 1
+    ind_after_axis_range = 1
+    for i, value in enumerate(ind_shape, 0):
+        if i < axis:
+            ind_before_axis_range *= value
+        elif i > axis:
+            ind_after_axis_range *= value
+    ind_before_axis_stride = ind_axis_range * ind_after_axis_range
+    ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range
+
+    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, reduce_func):
+        # pylint: disable=invalid-name
+        data = T.buffer_proxy(data_ptr)
+        indices = T.buffer_proxy(indices_ptr)
+        updates = T.buffer_proxy(updates_ptr)
+        out = T.buffer_proxy(out_ptr)
+
+        max_threads = int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])
+
+        with IRBuilder() as ib:
+            with T.seq_scope():
+                # Init
+                nthread_bx_init = cast(ceil_div(full_range, max_threads), "int32")
+                tx_init = te.thread_axis("threadIdx.x")
+                bx_init = te.thread_axis("blockIdx.x")
+                with T.frame_scope(
+                    [
+                        T.attr(bx_init, "thread_extent", nthread_bx_init),
+                        T.attr(tx_init, "thread_extent", max_threads),
+                    ]
+                ):
+                    tid = bx_init * max_threads + tx_init
+                    with T.If(tid < full_range):
+                        with T.Then():
+                            out[tid] = data[tid]
+
+                # Scatter
+                nthread_bx_scat = cast(ceil_div(ind_full_range_excl_axis, max_threads), "int32")
+                tx_scat = te.thread_axis("threadIdx.x")
+                bx_scat = te.thread_axis("blockIdx.x")
+                with T.frame_scope(
+                    [
+                        T.attr(bx_scat, "thread_extent", nthread_bx_scat),
+                        T.attr(tx_scat, "thread_extent", max_threads),
+                    ]
+                ):
+                    fused = bx_scat * max_threads + tx_scat
+                    with T.If(fused < ind_full_range_excl_axis):
+                        with T.Then():
+                            i = fused // ind_after_axis_range
+                            j = fused % ind_after_axis_range
+                            pre_index1 = i * ind_before_axis_stride + j
+                            pre_index2 = i * before_axis_stride + j
+                            with T.serial(0, ind_axis_range) as k:
+                                # Offset along indices or updates
+                                index1 = pre_index1 + k * ind_after_axis_range
+                                # Get index and shift to positive side if need
+                                k_new = indices[index1]
+                                shifted_index = k_new + (k_new < 0) * axis_range
+                                # Offset along data
+                                index2 = pre_index2 + shifted_index * after_axis_range
+                                reduce_func(out, index2, updates[index1])
+
+            return ib.get()
+
+    def update_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = update
+
+    def add_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] += update
+
+    def mul_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] *= update
+
+    def mean_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2
+
+    def min_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = tirx.min(dst_ptr[dst_index], update)
+
+    def max_func(dst_ptr, dst_index, update):
+        dst_ptr[dst_index] = tirx.max(dst_ptr[dst_index], update)
+
+    reduce_func = None
+    if reduction == "update":
+        reduce_func = update_func
+    elif reduction == "add":
+        reduce_func = add_func
+    elif reduction == "mul":
+        reduce_func = mul_func
+    elif reduction == "mean":
+        reduce_func = mean_func
+    elif reduction == "min":
+        reduce_func = min_func
+    elif reduction == "max":
+        reduce_func = max_func
+    else:
+        raise NotImplementedError(
+            "scatter_elements reduction not in [update, add, mul, mean, min, max]:", reduction
+        )
+
+    out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf")
+    return te.extern(
+        [data.shape],
+        [data, indices, updates],
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], reduce_func),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_elements.gpu",
+        tag="scatter_elements.gpu",
+    )
diff --git a/python/tvm/topi/gpu/scatter_nd.py b/python/tvm/topi/gpu/scatter_nd.py
new file mode 100644
index 0000000000..a29cd68a8e
--- /dev/null
+++ b/python/tvm/topi/gpu/scatter_nd.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+# ruff: noqa: E741
+"""scatter_nd related operators"""
+
+import tvm
+from tvm import te, tirx  # hide redefinition of min and max
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import tirx as T
+
+from ..math import cast
+from ..scatter import _verify_scatter_nd_inputs
+from ..utils import ceil_div
+
+
+def scatter_nd(data, indices, updates, mode):
+    """GPU implementation of scatter_nd with explicit thread bindings."""
+    _verify_scatter_nd_inputs(data, indices, updates)
+
+    def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
+        # pylint: disable=invalid-name
+        data = T.buffer_proxy(data_ptr)
+        indices = T.buffer_proxy(indices_ptr)
+        updates = T.buffer_proxy(updates_ptr)
+        out = T.buffer_proxy(out_ptr)
+
+        # We combine all the indices dimensions but the first one into a single
+        # dimension so we can iterate it in single loop instead of an arbitrary
+        # number of loops. We do the same thing for all the update dimensions.
+        fused_indices_dimension = 1
+        for i in indices_ptr.shape[1:]:
+            fused_indices_dimension *= i
+
+        fused_updates_dimension = 1
+        for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+            fused_updates_dimension *= i
+
+        fused_shape = 1
+        for i in data_ptr.shape:
+            fused_shape *= i
+
+        max_threads = int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])
+
+        with IRBuilder() as ib:
+            with T.seq_scope():
+                # Init
+                nthread_bx_init = cast(ceil_div(fused_shape, max_threads), "int32")
+                tx_init = te.thread_axis("threadIdx.x")
+                bx_init = te.thread_axis("blockIdx.x")
+                with T.frame_scope(
+                    [
+                        T.attr(bx_init, "thread_extent", nthread_bx_init),
+                        T.attr(tx_init, "thread_extent", max_threads),
+                    ]
+                ):
+                    tid = bx_init * max_threads + tx_init
+                    with T.If(tid < fused_shape):
+                        with T.Then():
+                            out[tid] = data[tid]
+
+                # Scatter
+                nthread_bx_scat = cast(ceil_div(fused_updates_dimension, max_threads), "int32")
+                tx_scat = te.thread_axis("threadIdx.x")
+                bx_scat = te.thread_axis("blockIdx.x")
+                with T.frame_scope(
+                    [
+                        T.attr(bx_scat, "thread_extent", nthread_bx_scat),
+                        T.attr(tx_scat, "thread_extent", max_threads),
+                    ]
+                ):
+                    j = bx_scat * max_threads + tx_scat
+                    with T.If(j < fused_updates_dimension):
+                        with T.Then():
+                            with T.serial(0, fused_indices_dimension) as i:
+                                offset = fused_updates_dimension
+                                index = j  # x_M, .. x_{N-1} part of the index into out.
+                                # Build up the indices[0, y_0, ..], ..,
+                                # indices[M-1, y_0, ..] part of the index into out.
+                                for l in reversed(range(indices_ptr.shape[0].value)):
+                                    # indices[l, y_0, ... y_{k-1}]
+                                    index += offset * indices[i + l * fused_indices_dimension]
+                                    offset *= data_ptr.shape[l]
+                                if mode == "update":
+                                    out[index] = updates[i * fused_updates_dimension + j]
+                                elif mode == "add":
+                                    out[index] += updates[i * fused_updates_dimension + j]
+                                elif mode == "mul":
+                                    out[index] *= updates[i * fused_updates_dimension + j]
+                                elif mode == "min":
+                                    out[index] = tirx.min(
+                                        out[index], updates[i * fused_updates_dimension + j]
+                                    )
+                                elif mode == "max":
+                                    out[index] = tirx.max(
+                                        out[index], updates[i * fused_updates_dimension + j]
+                                    )
+                                else:
+                                    raise NotImplementedError(
+                                        "scatter_nd mode not in [update, add, mul, min, max]:",
+                                        mode,
+                                    )
+
+            return ib.get()
+
+    out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf")
+    return te.extern(
+        [data.shape],
+        [data, indices, updates],
+        lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
+        dtype=data.dtype,
+        out_buffers=[out_buf],
+        name="scatter_nd.gpu",
+        tag="scatter_nd.gpu",
+    )
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 05b6c50c92..a8f1e906f5 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1551,6 +1551,29 @@ def test_scatter_elements_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)


+@tvm.testing.parametrize_targets("cuda")
+def test_scatter_elements_gpu(target, dev):
+    """scatter_elements lowered for GPU must build"""
+
+    @I.ir_module
+    class Mod:
+        @R.function
+        def main(
+            x: R.Tensor((4, 8), "float32"),
+            indices: R.Tensor((2, 8), "int64"),
+            updates: R.Tensor((2, 8), "float32"),
+        ):
+            with R.dataflow():
+                lv = R.scatter_elements(x, indices, updates, axis=0)
+                gv = lv
+                R.output(gv)
+            return gv
+
+    with tvm.target.Target(target):
+        mod = LegalizeOps()(Mod)
+    relax.build(mod, target=target)
+
+
 def test_layout_transform():
     transformation = lambda a, b, c: (a, c, b // 3, b % 3)
     pad_value = 2
@@ -1838,5 +1861,28 @@ def test_scatter_nd():
     tvm.ir.assert_structural_equal(After, Expected)


+@tvm.testing.parametrize_targets("cuda")
+def test_scatter_nd_gpu(target, dev):
+    """scatter_nd lowered for GPU must build"""
+
+    @I.ir_module
+    class Mod:
+        @R.function
+        def main(
+            data: R.Tensor((4, 8), "float32"),
+            indices: R.Tensor((3, 2), "int64"),
+            updates: R.Tensor((3,), "float32"),
+        ):
+            with R.dataflow():
+                lv = R.scatter_nd(data, indices, updates)
+                gv = lv
+                R.output(gv)
+            return gv
+
+    with tvm.target.Target(target):
+        mod = LegalizeOps()(Mod)
+    relax.build(mod, target=target)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to