This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new dc113955dc [Unity][OP] Add `rms_norm` (#15314)
dc113955dc is described below

commit dc113955dcfc304bb738bb08ce2fa3f5d52f92d8
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Jul 14 13:17:54 2023 -0700

    [Unity][OP] Add `rms_norm` (#15314)
    
    This PR introduces the root mean square normalization operator, `rms_norm`, into TOPI and Relax, along with its legalization transform.
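    
    In NumPy terms, the operator computes the following (a minimal reference
    sketch; the shapes and axes below are illustrative, and the helper mirrors
    the `rms_norm_python` reference added in this patch):
    
        import numpy as np
    
        def rms_norm_ref(data, weight, axis, epsilon=1e-5):
            # Mean of squares over the normalized axes; no mean subtraction.
            square_mean = np.mean(np.square(data), axis, keepdims=True)
            return data * weight / np.sqrt(square_mean + epsilon)
    
        x = np.random.rand(2, 3, 4).astype("float32")
        w = np.random.rand(3, 4).astype("float32")
        out = rms_norm_ref(x, w, axis=(1, 2))  # same shape as x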
---
 include/tvm/relax/attrs/nn.h                       |  11 +
 include/tvm/topi/nn/rms_norm.h                     |  94 ++++++++
 python/tvm/relax/op/nn/nn.py                       |  40 ++++
 python/tvm/relax/transform/legalize_ops/nn.py      |  11 +
 python/tvm/topi/nn/__init__.py                     |   1 +
 python/tvm/topi/nn/rms_norm.py                     |  45 ++++
 python/tvm/topi/testing/__init__.py                |   1 +
 python/tvm/topi/testing/rms_norm_python.py         |  48 ++++
 src/relax/op/nn/nn.cc                              |  59 +++++
 src/relax/op/nn/nn.h                               |   3 +
 src/topi/nn.cc                                     |   6 +
 .../python/relax/test_transform_legalize_ops_nn.py | 260 +++++++++++++++++++++
 tests/python/topi/python/test_topi_rms_norm.py     |  60 +++++
 13 files changed, 639 insertions(+)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 3368b66983..a59cf5e71f 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -296,6 +296,17 @@ struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
   }
 };  // struct GroupNormAttrs
 
+/*! \brief Attributes used in rms_norm operator */
+struct RMSNormAttrs : public tvm::AttrsNode<RMSNormAttrs> {
+  Array<Integer> axes;
+  double epsilon;
+
+  TVM_DECLARE_ATTRS(RMSNormAttrs, "relax.attrs.RMSNormAttrs") {
+    TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied.");
+    TVM_ATTR_FIELD(epsilon).describe("Small float added to the mean of squares to avoid dividing by zero");
+  }
+};  // struct RMSNormAttrs
+
 /*! \brief Attributes used in nll_loss operator */
 struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
   String reduction;
diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
new file mode 100644
index 0000000000..e743205611
--- /dev/null
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief root mean square normalization op constructions
+ * \file nn/rms_norm.h
+ */
+#ifndef TVM_TOPI_NN_RMS_NORM_H_
+#define TVM_TOPI_NN_RMS_NORM_H_
+
+#include <tvm/te/operation.h>
+#include <tvm/topi/reduction.h>
+#include <tvm/topi/tags.h>
+
+#include <string>
+
+namespace tvm {
+namespace topi {
+namespace nn {
+
+using namespace tvm::te;
+
+/*!
+ * \brief Root mean square normalization.
+ * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
+ * \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
+ *               d_{axis_k} == r_k
+ * \param axis The axes to normalize over.
+ * \param epsilon The epsilon value to avoid division by zero.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ * \return The normalized tensor, with the same shape as data.
+ */
+inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
+                       double epsilon, std::string name = "T_rms_norm",
+                       std::string tag = kInjective) {
+  const auto& data_type = data->dtype;
+  const auto& weight_type = weight.defined() ? weight->dtype : data_type;
+  ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
+  ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
+      << "rms_norm: only support float32 and float16 for now";
+  bool is_float16 = data_type == DataType::Float(16);
+
+  auto x = is_float16 ? cast(data, DataType::Float(32)) : data;
+  auto w = is_float16 ? cast(weight, DataType::Float(32)) : weight;
+  auto square = multiply(x, x);
+  auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
+
+  auto ndim = data->shape.size();
+  ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
+  auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
+  auto reduce_extent = make_const(data->dtype, 1);
+  for (int i : real_axis) {
+    reduce_extent *= data->shape[i];
+  }
+  auto rms_norm_func = [&](const Array<Var>& indices) {
+    Array<Var> reduce_indices, non_reduce_indices;
+    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
+      if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
+        reduce_indices.push_back(indices[i]);
+      } else {
+        non_reduce_indices.push_back(indices[i]);
+      }
+    }
+    auto output =
+        x(indices) * w(reduce_indices) *
+        tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+    return output;
+  };
+  auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
+  return is_float16 ? cast(rms_norm, DataType::Float(16)) : rms_norm;
+}
+
+}  // namespace nn
+}  // namespace topi
+}  // namespace tvm
+
+#endif  // TVM_TOPI_NN_RMS_NORM_H_
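
Note: the kernel above precomputes the sum of squares over the reduced axes, then applies a single elementwise step that multiplies by a reciprocal square root rather than dividing by a square root. In a worked form (N is `reduce_extent` in the code, the product of the reduced extents):

    \mathrm{out} = x \cdot w \cdot \mathrm{rsqrt}\left(\frac{1}{N}\sum_{k \in \mathrm{axis}} x_k^2 + \epsilon\right),
    \qquad N = \prod_{k \in \mathrm{axis}} d_k

which is algebraically equal to x * w / sqrt(mean(x^2, axis) + epsilon).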
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 9c4044636c..a88542e05a 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -925,6 +925,46 @@ def group_norm(
     )
 
 
+def rms_norm(
+    data: Expr,
+    weight: Expr,
+    axes: Union[int, List[int]],
+    epsilon: float = 1e-5,
+) -> Expr:
+    r"""
+    Root mean square normalization (Biao Zhang et al., 2019).
+    Applies root mean square normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array and normalizes
+    the input using the given axis:
+
+    .. math::
+
+        out = \frac{data}{\sqrt{mean(data^2, axis)+\epsilon}} * weight
+
+    Parameters
+    ----------
+    data : relax.Expr
+        Input to which rms_norm will be applied.
+
+    weight : relax.Expr
+        The scale factor.
+
+    axes : Union[int, List[int]]
+        The axes along which the normalization is applied.
+
+    epsilon : float
+        Small float added to the mean of squares to avoid dividing by zero.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(axes, int):
+        axes = [axes]
+    return _ffi_api.rms_norm(data, weight, axes, epsilon)  # type: ignore
+
+
 def dropout(data: Expr, rate: float = 0.5) -> Expr:
     """Applies the dropout operation to the input tensor.
 
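
A usage sketch of the new Relax binding (the module name and shapes are illustrative; this mirrors the legalization tests added later in this patch):

    import tvm
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, 4, 8), "float32"),
                 w: R.Tensor((8,), "float32")) -> R.Tensor((2, 4, 8), "float32"):
            gv: R.Tensor((2, 4, 8), "float32") = R.nn.rms_norm(x, w, axes=[-1])
            return gv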
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 85986f0240..257b8b79cc 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -334,6 +334,17 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.rms_norm")
+def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
+    return bb.call_te(
+        topi.nn.rms_norm,
+        call.args[0],
+        call.args[1],
+        axis=call.attrs.axes,
+        epsilon=call.attrs.epsilon,
+    )
+
+
 @register_legalize("relax.nn.dropout")
 def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
     logging.info("Dropout is handled by frontend translator at this moment and 
is not legalized.")
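
With the registration above, LegalizeOps rewrites relax.nn.rms_norm into a call_tir to the TOPI compute. A minimal sketch of running the pass (using the illustrative Module from the earlier sketch):

    from tvm.relax.transform import LegalizeOps

    lowered = LegalizeOps()(Module)  # main now calls a generated rms_norm prim_func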
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index d65c5c45c7..2c549cc5b9 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -41,6 +41,7 @@ from .upsampling import *
 from .instance_norm import instance_norm
 from .layer_norm import layer_norm
 from .group_norm import group_norm
+from .rms_norm import rms_norm
 from .local_response_norm import *
 from .bitserial_conv2d import *
 from .bitserial_dense import *
diff --git a/python/tvm/topi/nn/rms_norm.py b/python/tvm/topi/nn/rms_norm.py
new file mode 100644
index 0000000000..651ff361bf
--- /dev/null
+++ b/python/tvm/topi/nn/rms_norm.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Root mean square normalization operator."""
+from .. import cpp
+
+
+def rms_norm(data, weight, axis, epsilon=1e-5):
+    """Root mean square normalization operator.
+    It accepts fp16 and fp32 as input data type. It will cast the input to fp32
+    to perform the computation. The output will have the same data type as input.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+
+    weight: tvm.te.Tensor
+        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+    axis : list of int
+        The axes over which the normalization is applied.
+
+    epsilon : float
+        The epsilon value to avoid division by zero.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+    """
+    return cpp.nn.rms_norm(data, weight, axis, epsilon)
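
A minimal sketch of driving this TOPI compute at the TE level (shapes illustrative; the new test file at the end of this patch schedules, builds, and runs the same path):

    import tvm
    from tvm import te, topi

    data = te.placeholder((4, 16), dtype="float32", name="data")
    weight = te.placeholder((16,), dtype="float32", name="weight")
    out = topi.nn.rms_norm(data, weight, axis=[1], epsilon=1e-5)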
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index d950a20c05..093f84d99b 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -46,6 +46,7 @@ from .roi_pool_python import roi_pool_nchw_python
 from .instance_norm_python import instance_norm_python
 from .layer_norm_python import layer_norm_python
 from .group_norm_python import group_norm_python
+from .rms_norm_python import rms_norm_python
 from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
 from .gather_python import gather_python
diff --git a/python/tvm/topi/testing/rms_norm_python.py b/python/tvm/topi/testing/rms_norm_python.py
new file mode 100644
index 0000000000..0273b41941
--- /dev/null
+++ b/python/tvm/topi/testing/rms_norm_python.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Root mean square normalization in python"""
+import numpy as np
+
+
+def rms_norm_python(data, weight, axis, epsilon=1e-5):
+    """Root mean square normalization operator in Python.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+
+    weight: numpy.ndarray
+        K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+    axis : int or tuple of ints
+        The axes over which the normalization is applied.
+
+    epsilon : float
+        The epsilon value to avoid division by zero.
+
+    Returns
+    -------
+    result : np.ndarray
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+    """
+    old_dtype = data.dtype
+    data = data.astype("float32")
+    square_mean = np.mean(np.square(data), axis, keepdims=True)
+    result = data * weight / np.sqrt(square_mean + epsilon)
+    return result.astype(old_dtype)
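
A quick sanity check of this reference (illustrative values): with all-ones input the mean of squares is 1, so every output entry is 1/sqrt(1 + epsilon), roughly 0.999995:

    x = np.ones((2, 3), dtype="float32")
    w = np.ones((3,), dtype="float32")
    out = rms_norm_python(x, w, axis=1)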
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index b0d5b822d2..fa0d9182bb 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -437,6 +437,65 @@ TVM_REGISTER_OP("relax.nn.group_norm")
    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.rms_norm */
+TVM_REGISTER_NODE_TYPE(RMSNormAttrs);
+
+Expr rms_norm(Expr data, Expr weight, Array<Integer> axes, double epsilon) {
+  ObjectPtr<RMSNormAttrs> attrs = make_object<RMSNormAttrs>();
+  attrs->axes = std::move(axes);
+  attrs->epsilon = epsilon;
+
+  static const Op& op = Op::Get("relax.nn.rms_norm");
+  return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm);
+
+StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<RMSNormAttrs>();
+  bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
+
+  return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
+                       : input_sinfo[0];
+}
+
+InferLayoutOutput InferLayoutRMSNorm(const Call& call,
+                                     const Map<String, Array<String>>& desired_layouts,
+                                     const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  std::vector<NLayout> initial_layouts;
+  for (size_t i = 0; i < 2; ++i) {  // rms_norm has two inputs: data and weight
+    const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+    ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+    ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+    initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+  }
+  const auto* attrs = call->attrs.as<RMSNormAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<RMSNormAttrs> new_attrs = make_object<RMSNormAttrs>(*attrs);
+  std::vector<Integer> new_axis;
+  for (const auto& axis : attrs->axes) {
+    new_axis.push_back(FindAxis(layout->layout, axis->value));
+  }
+  new_attrs->axes = std::move(new_axis);
+  return InferLayoutOutput({layout, initial_layouts[1]}, {layout}, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.nn.rms_norm")
+    .set_attrs_type<RMSNormAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "Input to which batch_norm will be 
applied.")
+    .add_argument("weight", "Tensor", "Input to which batch_norm will be 
applied.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRMSNorm)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutRMSNorm)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.nn.dropout */
 TVM_REGISTER_NODE_TYPE(DropoutAttrs);
 
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index ce6b369b23..a3658fed54 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -78,6 +78,9 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double ep
 Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
                 Array<Integer> axes, double epsilon, bool center, bool scale);
 
+/*! \brief Compute root mean square normalization. */
+Expr rms_norm(Expr data, Expr weight, Array<Integer> axes, double epsilon);
+
 /*!
  * \brief Applies the dropout operation to the input tensor.
  * \param data The input data to the operator.
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index 58b962da6a..9ce329b206 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -35,6 +35,7 @@
 #include <tvm/topi/nn/local_response_norm.h>
 #include <tvm/topi/nn/mapping.h>
 #include <tvm/topi/nn/pooling.h>
+#include <tvm/topi/nn/rms_norm.h>
 #include <tvm/topi/nn/softmax.h>
 
 namespace tvm {
@@ -176,5 +177,10 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
   *rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
 });
 
+/* Ops from nn/rms_norm.h */
+TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue* 
rv) {
+  *rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
+});
+
 }  // namespace topi
 }  // namespace tvm
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 27c67e728d..e266e9013b 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2742,6 +2742,266 @@ def test_group_norm_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_rms_norm():
+    # fmt: off
+    @tvm.script.ir_module
+    class RMSNorm:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4, 5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
+            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, axes=[-2, -1])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def rms_norm(
+            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"),
+            B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
+            T_rms_norm: T.Buffer(
+                (T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"
+            ),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, v_ax3]
+                    )
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_multiply_red"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
+                    T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
+                    T.writes(T_multiply_red[v_ax0, v_ax1])
+                    with T.init():
+                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+                    T_multiply_red[v_ax0, v_ax1] = (
+                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+                    )
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(
+                        A[v_ax0, v_ax1, v_ax2, v_ax3],
+                        B[v_ax2, v_ax3],
+                        T_multiply_red[v_ax0, v_ax1],
+                    )
+                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        A[v_ax0, v_ax1, v_ax2, v_ax3]
+                        * B[v_ax2, v_ax3]
+                        * T.rsqrt(
+                            T_multiply_red[v_ax0, v_ax1] * T.float32(0.05)
+                            + T.float32(1e-05)
+                        )
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 4, 5), dtype="float32"),
+            weight: R.Tensor((4, 5), dtype="float32"),
+        ) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(
+                cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32")
+            )
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(RMSNorm)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rms_norm_fp16():
+    # fmt: off
+    @tvm.script.ir_module
+    class RMSNorm:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "float16"), weight: R.Tensor((4, 5), "float16")) -> R.Tensor((2, 3, 4, 5), "float16"):
+            gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.rms_norm(x, weight, axes=[-2, -1])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def rms_norm(
+            A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"),
+            B: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+            T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+            T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
+            T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+            T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+            T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_cast"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast(
+                        "float32", A[v_ax0, v_ax1, v_ax2, v_ax3]
+                    )
+            for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+                with T.block("T_cast_1"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(B[v_ax0, v_ax1])
+                    T.writes(T_cast_2[v_ax0, v_ax1])
+                    T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
+                        * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
+                    )
+            for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+                with T.block("T_multiply_red"):
+                    v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
+                    T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
+                    T.writes(T_multiply_red[v_ax0, v_ax1])
+                    with T.init():
+                        T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+                    T_multiply_red[v_ax0, v_ax1] = (
+                        T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+                    )
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(
+                        T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3],
+                        T_cast_2[v_ax2, v_ax3],
+                        T_multiply_red[v_ax0, v_ax1],
+                    )
+                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+                        T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
+                        * T_cast_2[v_ax2, v_ax3]
+                        * T.rsqrt(
+                            T_multiply_red[v_ax0, v_ax1]
+                            / T.Cast("float32", T.float16(4) * T.float16(5))
+                            + T.float32(1e-05)
+                        )
+                    )
+            for ax0, ax1, ax2, ax3 in T.grid(
+                T.int64(2), T.int64(3), T.int64(4), T.int64(5)
+            ):
+                with T.block("T_cast_2"):
+                    v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+                    T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
+                    T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast(
+                        "float16", T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 4, 5), dtype="float16"),
+            weight: R.Tensor((4, 5), dtype="float16"),
+        ) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
+            cls = Expected
+            gv = R.call_tir(
+                cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16")
+            )
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(RMSNorm)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rms_norm_symbolic():
+    # fmt: off
+    @tvm.script.ir_module
+    class RMSNorm:
+        @R.function
+        def main(x: R.Tensor(("n", "s", "f"), "float32"), weight: 
R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"):
+            n = T.int64()
+            s = T.int64()
+            f = T.int64()
+            gv: R.Tensor((n, s, f), "float32") = R.nn.rms_norm(x, weight, axes=[1, 2])
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def rms_norm(var_A: T.handle, var_B: T.handle, var_T_rms_norm: T.handle):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            n, s, f = T.int64(), T.int64(), T.int64()
+            A = T.match_buffer(var_A, (n, s, f))
+            B = T.match_buffer(var_B, (s, f))
+            T_rms_norm = T.match_buffer(var_T_rms_norm, (n, s, f))
+            # with T.block("root"):
+            T_multiply = T.alloc_buffer((n, s, f))
+            T_multiply_red = T.alloc_buffer((n,))
+            for ax0, ax1, ax2 in T.grid(n, s, f):
+                with T.block("T_multiply"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(A[v_ax0, v_ax1, v_ax2])
+                    T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+                    T_multiply[v_ax0, v_ax1, v_ax2] = (
+                        A[v_ax0, v_ax1, v_ax2] * A[v_ax0, v_ax1, v_ax2]
+                    )
+            for ax0, k1, k2 in T.grid(n, s, f):
+                with T.block("T_multiply_red"):
+                    v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
+                    T.reads(T_multiply[v_ax0, v_k1, v_k2])
+                    T.writes(T_multiply_red[v_ax0])
+                    with T.init():
+                        T_multiply_red[v_ax0] = T.float32(0)
+                    T_multiply_red[v_ax0] = (
+                        T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2]
+                    )
+            for ax0, ax1, ax2 in T.grid(n, s, f):
+                with T.block("T_rms_norm"):
+                    v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                    T.reads(A[v_ax0, v_ax1, v_ax2], B[v_ax1, v_ax2], T_multiply_red[v_ax0])
+                    T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
+                    T_rms_norm[v_ax0, v_ax1, v_ax2] = (
+                        A[v_ax0, v_ax1, v_ax2]
+                        * B[v_ax1, v_ax2]
+                        * T.rsqrt(
+                            T_multiply_red[v_ax0]
+                            / (T.Cast("float32", s) * T.Cast("float32", f))
+                            + T.float32(1e-05)
+                        )
+                    )
+
+        @R.function
+        def main(
+            x: R.Tensor(("n", "s", "f"), dtype="float32"),
+            weight: R.Tensor(("s", "f"), dtype="float32"),
+        ) -> R.Tensor(("n", "s", "f"), dtype="float32"):
+            n = T.int64()
+            s = T.int64()
+            f = T.int64()
+            cls = Expected
+            gv = R.call_tir(
+                cls.rms_norm, (x, weight), out_sinfo=R.Tensor((n, s, f), dtype="float32")
+            )
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(RMSNorm)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_attention():
     # fmt: off
     @tvm.script.ir_module
diff --git a/tests/python/topi/python/test_topi_rms_norm.py b/tests/python/topi/python/test_topi_rms_norm.py
new file mode 100644
index 0000000000..a30c5bbc97
--- /dev/null
+++ b/tests/python/topi/python/test_topi_rms_norm.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for rms_norm."""
+import numpy as np
+import pytest
+import tvm
+from tvm import te
+from tvm import topi
+from tvm.topi.utils import get_const_tuple
+import tvm.topi.testing
+
+import tvm.testing
+
+
+_rms_norm_schedule = {
+    "generic": topi.generic.schedule_injective,
+}
+
+
+# Only test on llvm, since schedules for other targets are missing.
+@tvm.testing.parametrize_targets("llvm")
+@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-4, atol=5e-4):
+    data = te.placeholder(shape, dtype=dtype, name="data")
+    scale_shape = [shape[dim] for dim in axis]
+    weight = te.placeholder(scale_shape, dtype=dtype, name="weight")
+    B = topi.nn.rms_norm(data, weight, axis, epsilon)
+
+    data_np = np.random.uniform(size=shape).astype(dtype)
+    weight_np = np.random.uniform(size=scale_shape).astype(dtype)
+    b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, epsilon)
+
+    with tvm.target.Target(target):
+        s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
+        s = s_func([B])
+    data_tvm = tvm.nd.array(data_np, dev)
+    weight_tvm = tvm.nd.array(weight_np, dev)
+    b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
+    f = tvm.build(s, [data, weight, B], target)
+    f(data_tvm, weight_tvm, b_tvm)
+    tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
