This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 1c209e27b7 [Relax] Clean up scatter_elements unknown dtype handling
(#18577)
1c209e27b7 is described below
commit 1c209e27b7b0c62fcb37968382ffcd1612319eab
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Dec 20 19:47:33 2025 +0800
[Relax] Clean up scatter_elements unknown dtype handling (#18577)
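## Why
- The removed TODO comments asked whether a warning counterpart to `ctx->ReportFatal` exists. None is needed: `LOG(WARNING)` is the standard and correct approach throughout the TVM codebase.
- The existing pattern is used consistently across all Relax ops (see test_op_manipulate.py, index.cc, etc.).
- Added test coverage for previously untested scenarios: an unknown dtype for `data` and an unknown dtype for `updates`.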
---
src/relax/op/tensor/manipulate.cc | 2 --
tests/python/relax/test_op_manipulate.py | 14 ++++++++++++++
2 files changed, 14 insertions(+), 2 deletions(-)
diff --git a/src/relax/op/tensor/manipulate.cc
b/src/relax/op/tensor/manipulate.cc
index 493198fbd0..1aab52ac56 100644
--- a/src/relax/op/tensor/manipulate.cc
+++ b/src/relax/op/tensor/manipulate.cc
@@ -2456,7 +2456,6 @@ StructInfo InferStructInfoScatterElements(const Call&
call, const BlockBuilder&
if (data_sinfo->IsUnknownDtype() || updates_sinfo->IsUnknownDtype()) {
auto diag_dtype = [&](const TensorStructInfoNode* sinfo, ffi::String name)
{
if (sinfo->IsUnknownDtype()) {
- // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for
warning?
LOG(WARNING) << "Data type of " << name
<< " has not been specified. Assume it has an integer
type.";
}
@@ -2473,7 +2472,6 @@ StructInfo InferStructInfoScatterElements(const Call&
call, const BlockBuilder&
}
if (indices_sinfo->IsUnknownDtype()) {
- // TODO(tvm-team): Do we have an equivalent of `ctx->ReportFatal` for
warning?
LOG(WARNING) << "Data type of indice has not been specified. Assume it has
an integer type.";
} else if (!(indices_sinfo->dtype.is_int() ||
indices_sinfo->dtype.is_uint())) {
ctx->ReportFatal(
diff --git a/tests/python/relax/test_op_manipulate.py
b/tests/python/relax/test_op_manipulate.py
index d39584e06b..6a73a84fd8 100644
--- a/tests/python/relax/test_op_manipulate.py
+++ b/tests/python/relax/test_op_manipulate.py
@@ -3417,6 +3417,20 @@ def test_scatter_elements_infer_struct_info():
relax.op.scatter_elements(d2, i3, u0, 0, "updates"),
relax.TensorStructInfo(dtype="float32", ndim=-1),
)
+ # Test with unknown dtype for data
+ d_unknown = relax.Var("data", R.Tensor((4, 4)))
+ _check_inference(
+ bb,
+ relax.op.scatter_elements(d_unknown, i0, u0, 0, "updates"),
+ relax.TensorStructInfo((4, 4), dtype=""),
+ )
+ # Test with unknown dtype for updates
+ u_unknown = relax.Var("updates", R.Tensor((2, 2)))
+ _check_inference(
+ bb,
+ relax.op.scatter_elements(d0, i0, u_unknown, 0, "updates"),
+ relax.TensorStructInfo((4, 4), dtype="float32"),
+ )
def test_scatter_elements_infer_struct_info_symbolic_shape():