This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fde09d2052 [BugFix][Relax] Fix scatter_elements and scatter_nd CUDA
compilation (#19497)
fde09d2052 is described below
commit fde09d2052d1b7238d6ac61f0b083baf64d7c098
Author: as4230 <[email protected]>
AuthorDate: Mon May 4 04:30:00 2026 -0400
[BugFix][Relax] Fix scatter_elements and scatter_nd CUDA compilation
(#19497)
`topi.scatter_elements` and `topi.scatter_nd` emit bare `T.parallel`
loops in their `te.extern` IRBuilder bodies, which trips `VerifyMemory`
on CUDA targets:
RuntimeError: Memory verification failed
...
Did you forget to bind?
CPU (LLVM) is unaffected.
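This fix adds GPU implementations, `topi/gpu/scatter_elements.py` and
`topi/gpu/scatter_nd.py`, whose IRBuilder bodies emit thread bindings
instead of `T.parallel`. The Relax legalization selects them when
`Target.current()` is a GPU target; other targets keep the existing
`topi.scatter_elements` / `topi.scatter_nd` implementations.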
Fixes #19451.
---
.../tvm/relax/transform/legalize_ops/manipulate.py | 12 +-
python/tvm/topi/gpu/__init__.py | 2 +
python/tvm/topi/gpu/scatter_elements.py | 162 +++++++++++++++++++++
python/tvm/topi/gpu/scatter_nd.py | 129 ++++++++++++++++
.../test_transform_legalize_ops_manipulate.py | 46 ++++++
5 files changed, 349 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/transform/legalize_ops/manipulate.py
b/python/tvm/relax/transform/legalize_ops/manipulate.py
index 2a1d249ef7..fc7ee0d12e 100644
--- a/python/tvm/relax/transform/legalize_ops/manipulate.py
+++ b/python/tvm/relax/transform/legalize_ops/manipulate.py
@@ -235,10 +235,16 @@ def _meshgrid(bb: BlockBuilder, call: Call) -> Expr:
)
+def _is_gpu_target():
+ target = tvm.target.Target.current(allow_none=True)
+ return target is not None and "gpu" in target.keys
+
+
@register_legalize("relax.scatter_elements")
def _scatter_elements(bb: BlockBuilder, call: Call) -> Expr:
+ te_func = topi.gpu.scatter_elements if _is_gpu_target() else
topi.scatter_elements
return bb.call_te(
- topi.scatter_elements,
+ te_func,
call.args[0],
call.args[1],
call.args[2],
@@ -250,10 +256,12 @@ def _scatter_elements(bb: BlockBuilder, call: Call) ->
Expr:
@register_legalize("relax.scatter_nd")
def _scatter_nd(bb: BlockBuilder, call: Call) -> Expr:
# TODO(relax-team): Support native scatter_nd without te extern
+ base_te = topi.gpu.scatter_nd if _is_gpu_target() else topi.scatter_nd
+
def scatter_nd(data, indices, updates, reduction):
axes = list(range(len(indices.shape)))
indices = topi.transpose(indices, axes[-1:] + axes[:-1])
- return topi.scatter_nd(data, indices, updates, reduction)
+ return base_te(data, indices, updates, reduction)
return bb.call_te(
scatter_nd,
diff --git a/python/tvm/topi/gpu/__init__.py b/python/tvm/topi/gpu/__init__.py
index e56a1d7123..69998957f3 100644
--- a/python/tvm/topi/gpu/__init__.py
+++ b/python/tvm/topi/gpu/__init__.py
@@ -20,4 +20,6 @@
"""GPU specific declaration."""
from .scan import cumsum, cumprod
+from .scatter_elements import scatter_elements
+from .scatter_nd import scatter_nd
from .sort import *
diff --git a/python/tvm/topi/gpu/scatter_elements.py
b/python/tvm/topi/gpu/scatter_elements.py
new file mode 100644
index 0000000000..a7d9421862
--- /dev/null
+++ b/python/tvm/topi/gpu/scatter_elements.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""scatter_elements related operators"""
+
+import tvm
+from tvm import te, tirx
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import tirx as T
+
+from .. import utils
+from ..math import cast
+from ..utils import ceil_div
+
+
+def scatter_elements(data, indices, updates, axis=0, reduction="update"):
+ """GPU implementation of scatter_elements with explicit thread bindings"""
+ if not isinstance(axis, int):
+ axis = utils.get_const_int(axis)
+
+ # Prepare ranges and strides
+ shape = data.shape
+ if axis < 0:
+ axis = len(shape) + axis
+ axis_range = cast(shape[axis], indices.dtype)
+
+ full_range = 1
+ after_axis_range = 1
+ for i, value in enumerate(shape, 0):
+ full_range *= value
+ if i > axis:
+ after_axis_range *= value
+ before_axis_stride = axis_range * after_axis_range
+
+ ind_shape = indices.shape
+ ind_axis_range = ind_shape[axis]
+
+ ind_before_axis_range = 1
+ ind_after_axis_range = 1
+ for i, value in enumerate(ind_shape, 0):
+ if i < axis:
+ ind_before_axis_range *= value
+ elif i > axis:
+ ind_after_axis_range *= value
+ ind_before_axis_stride = ind_axis_range * ind_after_axis_range
+ ind_full_range_excl_axis = ind_before_axis_range * ind_after_axis_range
+
+ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr, reduce_func):
+ # pylint: disable=invalid-name
+ data = T.buffer_proxy(data_ptr)
+ indices = T.buffer_proxy(indices_ptr)
+ updates = T.buffer_proxy(updates_ptr)
+ out = T.buffer_proxy(out_ptr)
+
+ max_threads =
int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])
+
+ with IRBuilder() as ib:
+ with T.seq_scope():
+ # Init
+ nthread_bx_init = cast(ceil_div(full_range, max_threads),
"int32")
+ tx_init = te.thread_axis("threadIdx.x")
+ bx_init = te.thread_axis("blockIdx.x")
+ with T.frame_scope(
+ [
+ T.attr(bx_init, "thread_extent", nthread_bx_init),
+ T.attr(tx_init, "thread_extent", max_threads),
+ ]
+ ):
+ tid = bx_init * max_threads + tx_init
+ with T.If(tid < full_range):
+ with T.Then():
+ out[tid] = data[tid]
+
+ # Scatter
+ nthread_bx_scat = cast(ceil_div(ind_full_range_excl_axis,
max_threads), "int32")
+ tx_scat = te.thread_axis("threadIdx.x")
+ bx_scat = te.thread_axis("blockIdx.x")
+ with T.frame_scope(
+ [
+ T.attr(bx_scat, "thread_extent", nthread_bx_scat),
+ T.attr(tx_scat, "thread_extent", max_threads),
+ ]
+ ):
+ fused = bx_scat * max_threads + tx_scat
+ with T.If(fused < ind_full_range_excl_axis):
+ with T.Then():
+ i = fused // ind_after_axis_range
+ j = fused % ind_after_axis_range
+ pre_index1 = i * ind_before_axis_stride + j
+ pre_index2 = i * before_axis_stride + j
+ with T.serial(0, ind_axis_range) as k:
+ # Offset along indices or updates
+ index1 = pre_index1 + k * ind_after_axis_range
+ # Get index and shift to positive side if needed
+ k_new = indices[index1]
+ shifted_index = k_new + (k_new < 0) *
axis_range
+ # Offset along data
+ index2 = pre_index2 + shifted_index *
after_axis_range
+ reduce_func(out, index2, updates[index1])
+
+ return ib.get()
+
+ def update_func(dst_ptr, dst_index, update):
+ dst_ptr[dst_index] = update
+
+ def add_func(dst_ptr, dst_index, update):
+ dst_ptr[dst_index] += update
+
+ def mul_func(dst_ptr, dst_index, update):
+ dst_ptr[dst_index] *= update
+
+ def mean_func(dst_ptr, dst_index, update):
+ dst_ptr[dst_index] = (dst_ptr[dst_index] + update) / 2
+
+ def min_func(dst_ptr, dst_index, update):
+ dst_ptr[dst_index] = tirx.min(dst_ptr[dst_index], update)
+
+ def max_func(dst_ptr, dst_index, update):
+ dst_ptr[dst_index] = tirx.max(dst_ptr[dst_index], update)
+
+ reduce_func = None
+ if reduction == "update":
+ reduce_func = update_func
+ elif reduction == "add":
+ reduce_func = add_func
+ elif reduction == "mul":
+ reduce_func = mul_func
+ elif reduction == "mean":
+ reduce_func = mean_func
+ elif reduction == "min":
+ reduce_func = min_func
+ elif reduction == "max":
+ reduce_func = max_func
+ else:
+ raise NotImplementedError(
+ "scatter_elements reduction not in [update, add, mul, mean, min,
max]:", reduction
+ )
+
+ out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf")
+ return te.extern(
+ [data.shape],
+ [data, indices, updates],
+ lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0], reduce_func),
+ dtype=data.dtype,
+ out_buffers=[out_buf],
+ name="scatter_elements.gpu",
+ tag="scatter_elements.gpu",
+ )
diff --git a/python/tvm/topi/gpu/scatter_nd.py
b/python/tvm/topi/gpu/scatter_nd.py
new file mode 100644
index 0000000000..a29cd68a8e
--- /dev/null
+++ b/python/tvm/topi/gpu/scatter_nd.py
@@ -0,0 +1,129 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+# ruff: noqa: E741
+"""scatter_nd related operators"""
+
+import tvm
+from tvm import te, tirx # hide redefinition of min and max
+from tvm.script.ir_builder import IRBuilder
+from tvm.script.ir_builder import tirx as T
+
+from ..math import cast
+from ..scatter import _verify_scatter_nd_inputs
+from ..utils import ceil_div
+
+
+def scatter_nd(data, indices, updates, mode):
+ """GPU implementation of scatter_nd with explicit thread bindings."""
+ _verify_scatter_nd_inputs(data, indices, updates)
+
+ def gen_ir(data_ptr, indices_ptr, updates_ptr, out_ptr):
+ # pylint: disable=invalid-name
+ data = T.buffer_proxy(data_ptr)
+ indices = T.buffer_proxy(indices_ptr)
+ updates = T.buffer_proxy(updates_ptr)
+ out = T.buffer_proxy(out_ptr)
+
+ # We combine all the indices dimensions but the first one into a single
+ # dimension so we can iterate it in single loop instead of an arbitrary
+ # number of loops. We do the same thing for all the update dimensions.
+ fused_indices_dimension = 1
+ for i in indices_ptr.shape[1:]:
+ fused_indices_dimension *= i
+
+ fused_updates_dimension = 1
+ for i in updates_ptr.shape[len(indices_ptr.shape) - 1 :]:
+ fused_updates_dimension *= i
+
+ fused_shape = 1
+ for i in data_ptr.shape:
+ fused_shape *= i
+
+ max_threads =
int(tvm.target.Target.current(allow_none=False).attrs["max_num_threads"])
+
+ with IRBuilder() as ib:
+ with T.seq_scope():
+ # Init
+ nthread_bx_init = cast(ceil_div(fused_shape, max_threads),
"int32")
+ tx_init = te.thread_axis("threadIdx.x")
+ bx_init = te.thread_axis("blockIdx.x")
+ with T.frame_scope(
+ [
+ T.attr(bx_init, "thread_extent", nthread_bx_init),
+ T.attr(tx_init, "thread_extent", max_threads),
+ ]
+ ):
+ tid = bx_init * max_threads + tx_init
+ with T.If(tid < fused_shape):
+ with T.Then():
+ out[tid] = data[tid]
+
+ # Scatter
+ nthread_bx_scat = cast(ceil_div(fused_updates_dimension,
max_threads), "int32")
+ tx_scat = te.thread_axis("threadIdx.x")
+ bx_scat = te.thread_axis("blockIdx.x")
+ with T.frame_scope(
+ [
+ T.attr(bx_scat, "thread_extent", nthread_bx_scat),
+ T.attr(tx_scat, "thread_extent", max_threads),
+ ]
+ ):
+ j = bx_scat * max_threads + tx_scat
+ with T.If(j < fused_updates_dimension):
+ with T.Then():
+ with T.serial(0, fused_indices_dimension) as i:
+ offset = fused_updates_dimension
+ index = j # x_M, .. x_{N-1} part of the index
into out.
+ # Build up the indices[0, y_0, ..], ..,
+ # indices[M-1, y_0, ..] part of the index into
out.
+ for l in
reversed(range(indices_ptr.shape[0].value)):
+ # indices[l, y_0, ... y_{k-1}]
+ index += offset * indices[i + l *
fused_indices_dimension]
+ offset *= data_ptr.shape[l]
+ if mode == "update":
+ out[index] = updates[i *
fused_updates_dimension + j]
+ elif mode == "add":
+ out[index] += updates[i *
fused_updates_dimension + j]
+ elif mode == "mul":
+ out[index] *= updates[i *
fused_updates_dimension + j]
+ elif mode == "min":
+ out[index] = tirx.min(
+ out[index], updates[i *
fused_updates_dimension + j]
+ )
+ elif mode == "max":
+ out[index] = tirx.max(
+ out[index], updates[i *
fused_updates_dimension + j]
+ )
+ else:
+ raise NotImplementedError(
+ "scatter_nd mode not in [update, add,
mul, min, max]:",
+ mode,
+ )
+
+ return ib.get()
+
+ out_buf = tirx.decl_buffer(data.shape, data.dtype, "out_buf")
+ return te.extern(
+ [data.shape],
+ [data, indices, updates],
+ lambda ins, outs: gen_ir(ins[0], ins[1], ins[2], outs[0]),
+ dtype=data.dtype,
+ out_buffers=[out_buf],
+ name="scatter_nd.gpu",
+ tag="scatter_nd.gpu",
+ )
diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py
b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 05b6c50c92..a8f1e906f5 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1551,6 +1551,29 @@ def test_scatter_elements_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
[email protected]_targets("cuda")
+def test_scatter_elements_gpu(target, dev):
+ """scatter_elements lowered for GPU must build"""
+
+ @I.ir_module
+ class Mod:
+ @R.function
+ def main(
+ x: R.Tensor((4, 8), "float32"),
+ indices: R.Tensor((2, 8), "int64"),
+ updates: R.Tensor((2, 8), "float32"),
+ ):
+ with R.dataflow():
+ lv = R.scatter_elements(x, indices, updates, axis=0)
+ gv = lv
+ R.output(gv)
+ return gv
+
+ with tvm.target.Target(target):
+ mod = LegalizeOps()(Mod)
+ relax.build(mod, target=target)
+
+
def test_layout_transform():
transformation = lambda a, b, c: (a, c, b // 3, b % 3)
pad_value = 2
@@ -1838,5 +1861,28 @@ def test_scatter_nd():
tvm.ir.assert_structural_equal(After, Expected)
[email protected]_targets("cuda")
+def test_scatter_nd_gpu(target, dev):
+ """scatter_nd lowered for GPU must build"""
+
+ @I.ir_module
+ class Mod:
+ @R.function
+ def main(
+ data: R.Tensor((4, 8), "float32"),
+ indices: R.Tensor((3, 2), "int64"),
+ updates: R.Tensor((3,), "float32"),
+ ):
+ with R.dataflow():
+ lv = R.scatter_nd(data, indices, updates)
+ gv = lv
+ R.output(gv)
+ return gv
+
+ with tvm.target.Target(target):
+ mod = LegalizeOps()(Mod)
+ relax.build(mod, target=target)
+
+
if __name__ == "__main__":
tvm.testing.main()