This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ee1eb3dcf6 [Bug] Fix core dump in InferLayoutRMSNorm and fix typo (#18210)
ee1eb3dcf6 is described below

commit ee1eb3dcf61fc6aabb47625eed26cf44ecef862e
Author: chenxinli <39092231+cccxi...@users.noreply.github.com>
AuthorDate: Fri Aug 15 20:28:34 2025 +0800

    [Bug] Fix core dump in InferLayoutRMSNorm and fix typo (#18210)
    
    Fix core dump in InferLayoutRMSNorm and fix typo
---
 python/tvm/relax/op/nn/nn.py | 5 +----
 src/relax/op/nn/nn.cc        | 9 ++++-----
 2 files changed, 5 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 5834cf14d2..a38b31c9bb 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1813,7 +1813,7 @@ def rms_norm(
 
     .. math::
 
-        out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight + bias
+        out = \frac{data}{\sqrt{mean(data, axis)+\epsilon}} * weight
 
     Parameters
     ----------
@@ -1823,9 +1823,6 @@ def rms_norm(
     weight : relax.Expr
         The scale factor.
 
-    bias : relax.Expr
-        The offset factor.
-
     axes : Union[int, List[int]]
         The axes that along which the normalization is applied.
 
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index 344c9bc7a3..3597b16a5b 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -848,13 +848,12 @@ InferLayoutOutput InferLayoutRMSNorm(const Call& call,
 
   LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
   ObjectPtr<RMSNormAttrs> new_attrs = make_object<RMSNormAttrs>(*attrs);
-  std::vector<Integer> new_axis;
+  std::vector<Integer> new_axes;
   for (const auto& axis : attrs->axes) {
-    new_axis.push_back(FindAxis(layout->layout, axis->value));
+    new_axes.push_back(FindAxis(layout->layout, axis->value));
   }
-  new_attrs->axes = std::move(new_axis);
-  return InferLayoutOutput({layout, initial_layouts[1], initial_layouts[2]}, {layout},
-                           Attrs(new_attrs));
+  new_attrs->axes = std::move(new_axes);
+  return InferLayoutOutput({layout, initial_layouts[1]}, {layout}, Attrs(new_attrs));
 }
 
 TVM_REGISTER_OP("relax.nn.rms_norm")

Reply via email to