This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new dc113955dc [Unity][OP] Add `rms_norm` (#15314)
dc113955dc is described below
commit dc113955dcfc304bb738bb08ce2fa3f5d52f92d8
Author: Yaxing Cai <[email protected]>
AuthorDate: Fri Jul 14 13:17:54 2023 -0700
[Unity][OP] Add `rms_norm` (#15314)
This PR introduces the root mean square normalization operator, `rms_norm`,
into TOPI and Relax, along with its legalization transform.
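For context, a minimal usage sketch of the new Relax API, mirroring the tests added
in this patch (shapes and axes are illustrative only):

    import tvm
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Module:
        @R.function
        def main(
            x: R.Tensor((2, 3, 4, 5), "float32"),
            weight: R.Tensor((4, 5), "float32"),
        ) -> R.Tensor((2, 3, 4, 5), "float32"):
            # Normalize over the trailing two axes; weight carries the reduced shape (4, 5).
            gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight, axes=[-2, -1])
            return gv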
---
include/tvm/relax/attrs/nn.h | 11 +
include/tvm/topi/nn/rms_norm.h | 94 ++++++++
python/tvm/relax/op/nn/nn.py | 40 ++++
python/tvm/relax/transform/legalize_ops/nn.py | 11 +
python/tvm/topi/nn/__init__.py | 1 +
python/tvm/topi/nn/rms_norm.py | 45 ++++
python/tvm/topi/testing/__init__.py | 1 +
python/tvm/topi/testing/rms_norm_python.py | 48 ++++
src/relax/op/nn/nn.cc | 59 +++++
src/relax/op/nn/nn.h | 3 +
src/topi/nn.cc | 6 +
.../python/relax/test_transform_legalize_ops_nn.py | 260 +++++++++++++++++++++
tests/python/topi/python/test_topi_rms_norm.py | 60 +++++
13 files changed, 639 insertions(+)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 3368b66983..a59cf5e71f 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -296,6 +296,17 @@ struct GroupNormAttrs : public tvm::AttrsNode<GroupNormAttrs> {
}
}; // struct GroupNormAttrs
+/*! \brief Attributes used in rms_norm operator */
+struct RMSNormAttrs : public tvm::AttrsNode<RMSNormAttrs> {
+ Array<Integer> axes;
+ double epsilon;
+
+ TVM_DECLARE_ATTRS(RMSNormAttrs, "relax.attrs.RMSNormAttrs") {
+ TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied.");
+ TVM_ATTR_FIELD(epsilon).describe("Small float added to the mean square to avoid dividing by zero");
+ }
+}; // struct RMSNormAttrs
+
/*! \brief Attributes used in nll_loss operator */
struct NLLLossAttrs : public tvm::AttrsNode<NLLLossAttrs> {
String reduction;
diff --git a/include/tvm/topi/nn/rms_norm.h b/include/tvm/topi/nn/rms_norm.h
new file mode 100644
index 0000000000..e743205611
--- /dev/null
+++ b/include/tvm/topi/nn/rms_norm.h
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief root mean square normalization op constructions
+ * \file nn/rms_norm.h
+ */
+#ifndef TVM_TOPI_NN_RMS_NORM_H_
+#define TVM_TOPI_NN_RMS_NORM_H_
+
+#include <tvm/te/operation.h>
+#include <tvm/topi/reduction.h>
+#include <tvm/topi/tags.h>
+
+#include <string>
+
+namespace tvm {
+namespace topi {
+namespace nn {
+
+using namespace tvm::te;
+
+/*!
+ * \brief Root mean square normalization.
+ * \param data N-D tensor with shape [d_0, d_1, ..., d_{N-1}]
+ * \param weight K-D tensor with shape [r_0, r_1, ..., r_{K-1}] where K == len(axis) and
+ *        d_{axis_k} == r_k
+ * \param axis The axis to normalize over.
+ * \param epsilon The epsilon value to avoid division by zero.
+ * \param name The name of the operation.
+ * \param tag The tag to mark the operation.
+ * \return The normalized tensor, with the same shape as data.
+ */
+inline Tensor rms_norm(const Tensor& data, const Tensor& weight, const Array<Integer>& axis,
+ double epsilon, std::string name = "T_rms_norm",
+ std::string tag = kInjective) {
+ const auto& data_type = data->dtype;
+ const auto& weight_type = weight.defined() ? weight->dtype : data_type;
+ ICHECK(data_type == weight_type) << "rms_norm: data and weight must have the same type";
+ ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
+ << "rms_norm: only support float32 and float16 for now";
+ bool is_float16 = data_type == DataType::Float(16);
+
+ auto x = is_float16 ? cast(data, DataType::Float(32)) : data;
+ auto w = is_float16 ? cast(weight, DataType::Float(32)) : weight;
+ auto square = multiply(x, x);
+ auto square_sum = sum(square, axis, /*keepdims=*/false, /*atleast1d=*/true);
+
+ auto ndim = data->shape.size();
+ ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
+ auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
+ auto reduce_extent = make_const(data->dtype, 1);
+ for (int i : real_axis) {
+ reduce_extent *= data->shape[i];
+ }
+ auto rms_norm_func = [&](const Array<Var>& indices) {
+ Array<Var> reduce_indices, non_reduce_indices;
+ for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
+ if (std::find(real_axis.begin(), real_axis.end(), i) != real_axis.end()) {
+ reduce_indices.push_back(indices[i]);
+ } else {
+ non_reduce_indices.push_back(indices[i]);
+ }
+ }
+ auto output =
+ x(indices) * w(reduce_indices) *
+ tvm::rsqrt(square_sum(non_reduce_indices) / reduce_extent + make_const(data_type, epsilon));
+ return output;
+ };
+ auto rms_norm = tvm::te::compute(data->shape, rms_norm_func, name, tag);
+ return is_float16 ? cast(rms_norm, DataType::Float(16)) : rms_norm;
+}
+
+} // namespace nn
+} // namespace topi
+} // namespace tvm
+
+#endif // TVM_TOPI_NN_RMS_NORM_H_
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 9c4044636c..a88542e05a 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -925,6 +925,46 @@ def group_norm(
)
+def rms_norm(
+ data: Expr,
+ weight: Expr,
+ axes: Union[int, List[int]],
+ epsilon: float = 1e-5,
+) -> Expr:
+ r"""
+ Root mean square normalization (Biao Zhang et al., 2019).
+ Applies root mean square normalization to the n-dimensional input array.
+ This operator takes an n-dimensional input array and normalizes
+ the input using the given axis:
+
+ .. math::
+
+ out = \frac{data}{\sqrt{mean(data^2, axis)+\epsilon}} * weight
+
+ Parameters
+ ----------
+ data : relax.Expr
+ Input to which rms_norm will be applied.
+
+ weight : relax.Expr
+ The scale factor.
+
+ axes : Union[int, List[int]]
+ The axes along which the normalization is applied.
+
+ epsilon : float
+ Small float added to the mean square to avoid dividing by zero.
+
+ Returns
+ -------
+ result : relax.Expr
+ The computed result.
+ """
+ if isinstance(axes, int):
+ axes = [axes]
+ return _ffi_api.rms_norm(data, weight, axes, epsilon) # type: ignore
+
+
def dropout(data: Expr, rate: float = 0.5) -> Expr:
"""Applies the dropout operation to the input tensor.
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 85986f0240..257b8b79cc 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -334,6 +334,17 @@ def _nn_group_norm(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.rms_norm")
+def _nn_rms_norm(bb: BlockBuilder, call: Call) -> Expr:
+ return bb.call_te(
+ topi.nn.rms_norm,
+ call.args[0],
+ call.args[1],
+ axis=call.attrs.axes,
+ epsilon=call.attrs.epsilon,
+ )
+
+
@register_legalize("relax.nn.dropout")
def _nn_dropout(bb: BlockBuilder, call: Call) -> Expr:
logging.info("Dropout is handled by frontend translator at this moment and
is not legalized.")
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index d65c5c45c7..2c549cc5b9 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -41,6 +41,7 @@ from .upsampling import *
from .instance_norm import instance_norm
from .layer_norm import layer_norm
from .group_norm import group_norm
+from .rms_norm import rms_norm
from .local_response_norm import *
from .bitserial_conv2d import *
from .bitserial_dense import *
diff --git a/python/tvm/topi/nn/rms_norm.py b/python/tvm/topi/nn/rms_norm.py
new file mode 100644
index 0000000000..651ff361bf
--- /dev/null
+++ b/python/tvm/topi/nn/rms_norm.py
@@ -0,0 +1,45 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Root mean square normalization operator."""
+from .. import cpp
+
+
+def rms_norm(data, weight, axis, epsilon=1e-5):
+ """Root mean square normalization operator.
+ It accepts fp16 and fp32 as input data types. It will cast the input to fp32
+ to perform the computation. The output will have the same data type as the input.
+
+ Parameters
+ ----------
+ data : tvm.te.Tensor
+ N-D with shape (d_0, d_1, ..., d_{N-1})
+
+ weight: tvm.te.Tensor
+ K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+ axis : list of int
+ The axes over which the normalization is applied
+
+ epsilon : float
+ The epsilon value to avoid division by zero.
+
+ Returns
+ -------
+ result : tvm.te.Tensor
+ N-D with shape (d_0, d_1, ..., d_{N-1})
+ """
+ return cpp.nn.rms_norm(data, weight, axis, epsilon)
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index d950a20c05..093f84d99b 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -46,6 +46,7 @@ from .roi_pool_python import roi_pool_nchw_python
from .instance_norm_python import instance_norm_python
from .layer_norm_python import layer_norm_python
from .group_norm_python import group_norm_python
+from .rms_norm_python import rms_norm_python
from .lrn_python import lrn_python
from .l2_normalize_python import l2_normalize_python
from .gather_python import gather_python
diff --git a/python/tvm/topi/testing/rms_norm_python.py b/python/tvm/topi/testing/rms_norm_python.py
new file mode 100644
index 0000000000..0273b41941
--- /dev/null
+++ b/python/tvm/topi/testing/rms_norm_python.py
@@ -0,0 +1,48 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Root mean square normalization in python"""
+import numpy as np
+
+
+def rms_norm_python(data, weight, axis, epsilon=1e-5):
+ """Root mean square normalization operator in Python.
+
+ Parameters
+ ----------
+ data : numpy.ndarray
+ N-D with shape (d_0, d_1, ..., d_{N-1})
+
+ weight: numpy.ndarray
+ K-D with shape (r_0, r_1, ..., r_{K-1}) where K == len(axis) and d_{axis_k} == r_k
+
+ axis : int or tuple of ints
+ The axes over which the normalization is applied
+
+ epsilon : float
+ The epsilon value to avoid division by zero.
+
+ Returns
+ -------
+ result : np.ndarray
+ N-D with shape (d_0, d_1, ..., d_{N-1})
+ """
+ old_dtype = data.dtype
+ data = data.astype("float32")
+ square_mean = np.mean(np.square(data), axis, keepdims=True)
+ result = data * weight / np.sqrt(square_mean + epsilon)
+ return result.astype(old_dtype)
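As a quick illustration of the reference implementation above (a hedged sketch, not
part of the patch; the array values are arbitrary):

    import numpy as np

    x = np.random.rand(2, 8).astype("float32")
    w = np.ones(8, dtype="float32")
    out = rms_norm_python(x, w, axis=(1,))
    # With a unit weight, each row is scaled by the reciprocal of its root mean square.
    expected = x / np.sqrt(np.mean(np.square(x), axis=1, keepdims=True) + 1e-5)
    np.testing.assert_allclose(out, expected, rtol=1e-5)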
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index b0d5b822d2..fa0d9182bb 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -437,6 +437,65 @@ TVM_REGISTER_OP("relax.nn.group_norm")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.rms_norm */
+TVM_REGISTER_NODE_TYPE(RMSNormAttrs);
+
+Expr rms_norm(Expr data, Expr weight, Array<Integer> axes, double epsilon) {
+ ObjectPtr<RMSNormAttrs> attrs = make_object<RMSNormAttrs>();
+ attrs->axes = std::move(axes);
+ attrs->epsilon = epsilon;
+
+ static const Op& op = Op::Get("relax.nn.rms_norm");
+ return Call(op, {std::move(data), std::move(weight)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.rms_norm").set_body_typed(rms_norm);
+
+StructInfo InferStructInfoRMSNorm(const Call& call, const BlockBuilder& ctx) {
+ Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+
+ const auto* attrs = call->attrs.as<RMSNormAttrs>();
+ bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
+
+ return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
+ : input_sinfo[0];
+}
+
+InferLayoutOutput InferLayoutRMSNorm(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ std::vector<NLayout> initial_layouts;
+ for (size_t i = 0; i < 2; ++i) {
+ const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[i]);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support known ndim";
+ initial_layouts.push_back(InitialLayoutDecision(tensor_sinfo->ndim));
+ }
+ const auto* attrs = call->attrs.as<RMSNormAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<RMSNormAttrs> new_attrs = make_object<RMSNormAttrs>(*attrs);
+ std::vector<Integer> new_axis;
+ for (const auto& axis : attrs->axes) {
+ new_axis.push_back(FindAxis(layout->layout, axis->value));
+ }
+ new_attrs->axes = std::move(new_axis);
+ return InferLayoutOutput({layout, initial_layouts[1]}, {layout}, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.nn.rms_norm")
+ .set_attrs_type<RMSNormAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "Input to which batch_norm will be
applied.")
+ .add_argument("weight", "Tensor", "Input to which batch_norm will be
applied.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoRMSNorm)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutRMSNorm)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.nn.dropout */
TVM_REGISTER_NODE_TYPE(DropoutAttrs);
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index ce6b369b23..a3658fed54 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -78,6 +78,9 @@ Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double ep
Expr group_norm(Expr data, Expr gamma, Expr beta, int num_groups, int channel_axis,
Array<Integer> axes, double epsilon, bool center, bool scale);
+/*! \brief Compute root mean square normalization. */
+Expr rms_norm(Expr data, Expr weight, Array<Integer> axes, double epsilon);
+
/*!
* \brief Applies the dropout operation to the input tensor.
* \param data The input data to the operator.
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index 58b962da6a..9ce329b206 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -35,6 +35,7 @@
#include <tvm/topi/nn/local_response_norm.h>
#include <tvm/topi/nn/mapping.h>
#include <tvm/topi/nn/pooling.h>
+#include <tvm/topi/nn/rms_norm.h>
#include <tvm/topi/nn/softmax.h>
namespace tvm {
@@ -176,5 +177,10 @@ TVM_REGISTER_GLOBAL("topi.nn.instance_norm").set_body([](TVMArgs args, TVMRetVal
*rv = nn::instance_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
});
+/* Ops from nn/rms_norm.h */
+TVM_REGISTER_GLOBAL("topi.nn.rms_norm").set_body([](TVMArgs args, TVMRetValue*
rv) {
+ *rv = nn::rms_norm(args[0], args[1], args[2], static_cast<double>(args[3]));
+});
+
} // namespace topi
} // namespace tvm
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 27c67e728d..e266e9013b 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2742,6 +2742,266 @@ def test_group_norm_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_rms_norm():
+ # fmt: off
+ @tvm.script.ir_module
+ class RMSNorm:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), "float32"), weight: R.Tensor((4,
5), "float32")) -> R.Tensor((2, 3, 4, 5), "float32"):
+ gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.rms_norm(x, weight,
axes=[-2, -1])
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def rms_norm(
+ A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"),
+ B: T.Buffer((T.int64(4), T.int64(5)), "float32"),
+ T_rms_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+ T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ A[v_ax0, v_ax1, v_ax2, v_ax3] * A[v_ax0, v_ax1, v_ax2, v_ax3]
+ )
+ for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_multiply_red"):
+ v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
+ T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
+ T.writes(T_multiply_red[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+ T_multiply_red[v_ax0, v_ax1] = (
+ T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ )
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_rms_norm"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(
+ A[v_ax0, v_ax1, v_ax2, v_ax3],
+ B[v_ax2, v_ax3],
+ T_multiply_red[v_ax0, v_ax1],
+ )
+ T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ A[v_ax0, v_ax1, v_ax2, v_ax3]
+ * B[v_ax2, v_ax3]
+ * T.rsqrt(
+ T_multiply_red[v_ax0, v_ax1] * T.float32(0.05)
+ + T.float32(1e-05)
+ )
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 4, 5), dtype="float32"),
+ weight: R.Tensor((4, 5), dtype="float32"),
+ ) -> R.Tensor((2, 3, 4, 5), dtype="float32"):
+ cls = Expected
+ gv = R.call_tir(
+ cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float32")
+ )
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(RMSNorm)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rms_norm_fp16():
+ # fmt: off
+ @tvm.script.ir_module
+ class RMSNorm:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), "float16"), weight: R.Tensor((4,
5), "float16")) -> R.Tensor((2, 3, 4, 5), "float16"):
+ gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.rms_norm(x, weight,
axes=[-2, -1])
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def rms_norm(
+ A: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"),
+ B: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+ T_cast: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float16"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ T_cast_1 = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+ T_cast_2 = T.alloc_buffer((T.int64(4), T.int64(5)))
+ T_multiply = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+ T_multiply_red = T.alloc_buffer((T.int64(2), T.int64(3)))
+ T_rms_norm = T.alloc_buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)))
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_cast"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(A[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast(
+ "float32", A[v_ax0, v_ax1, v_ax2, v_ax3]
+ )
+ for ax0, ax1 in T.grid(T.int64(4), T.int64(5)):
+ with T.block("T_cast_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(B[v_ax0, v_ax1])
+ T.writes(T_cast_2[v_ax0, v_ax1])
+ T_cast_2[v_ax0, v_ax1] = T.Cast("float32", B[v_ax0, v_ax1])
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_multiply[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
+ * T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
+ )
+ for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_multiply_red"):
+ v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1, k2, k3])
+ T.reads(T_multiply[v_ax0, v_ax1, v_k2, v_k3])
+ T.writes(T_multiply_red[v_ax0, v_ax1])
+ with T.init():
+ T_multiply_red[v_ax0, v_ax1] = T.float32(0)
+ T_multiply_red[v_ax0, v_ax1] = (
+ T_multiply_red[v_ax0, v_ax1] + T_multiply[v_ax0, v_ax1, v_k2, v_k3]
+ )
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_rms_norm"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(
+ T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3],
+ T_cast_2[v_ax2, v_ax3],
+ T_multiply_red[v_ax0, v_ax1],
+ )
+ T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (
+ T_cast_1[v_ax0, v_ax1, v_ax2, v_ax3]
+ * T_cast_2[v_ax2, v_ax3]
+ * T.rsqrt(
+ T_multiply_red[v_ax0, v_ax1]
+ / T.Cast("float32", T.float16(4) * T.float16(5))
+ + T.float32(1e-05)
+ )
+ )
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3), T.int64(4), T.int64(5)):
+ with T.block("T_cast_2"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+ T.reads(T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_cast[v_ax0, v_ax1, v_ax2, v_ax3] = T.Cast(
+ "float16", T_rms_norm[v_ax0, v_ax1, v_ax2, v_ax3]
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor((2, 3, 4, 5), dtype="float16"),
+ weight: R.Tensor((4, 5), dtype="float16"),
+ ) -> R.Tensor((2, 3, 4, 5), dtype="float16"):
+ cls = Expected
+ gv = R.call_tir(
+ cls.rms_norm, (x, weight), out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16")
+ )
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(RMSNorm)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_rms_norm_symbolic():
+ # fmt: off
+ @tvm.script.ir_module
+ class RMSNorm:
+ @R.function
+ def main(x: R.Tensor(("n", "s", "f"), "float32"), weight:
R.Tensor(("s", "f"), "float32")) -> R.Tensor(("n", "s", "f"), "float32"):
+ n = T.int64()
+ s = T.int64()
+ f = T.int64()
+ gv: R.Tensor((n, s, f), "float32") = R.nn.rms_norm(x, weight, axes=[1, 2])
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @T.prim_func
+ def rms_norm(var_A: T.handle, var_B: T.handle, var_T_rms_norm: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ n, s, f = T.int64(), T.int64(), T.int64()
+ A = T.match_buffer(var_A, (n, s, f))
+ B = T.match_buffer(var_B, (s, f))
+ T_rms_norm = T.match_buffer(var_T_rms_norm, (n, s, f))
+ # with T.block("root"):
+ T_multiply = T.alloc_buffer((n, s, f))
+ T_multiply_red = T.alloc_buffer((n,))
+ for ax0, ax1, ax2 in T.grid(n, s, f):
+ with T.block("T_multiply"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(A[v_ax0, v_ax1, v_ax2])
+ T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
+ T_multiply[v_ax0, v_ax1, v_ax2] = (
+ A[v_ax0, v_ax1, v_ax2] * A[v_ax0, v_ax1, v_ax2]
+ )
+ for ax0, k1, k2 in T.grid(n, s, f):
+ with T.block("T_multiply_red"):
+ v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
+ T.reads(T_multiply[v_ax0, v_k1, v_k2])
+ T.writes(T_multiply_red[v_ax0])
+ with T.init():
+ T_multiply_red[v_ax0] = T.float32(0)
+ T_multiply_red[v_ax0] = (
+ T_multiply_red[v_ax0] + T_multiply[v_ax0, v_k1, v_k2]
+ )
+ for ax0, ax1, ax2 in T.grid(n, s, f):
+ with T.block("T_rms_norm"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(A[v_ax0, v_ax1, v_ax2], B[v_ax1, v_ax2], T_multiply_red[v_ax0])
+ T.writes(T_rms_norm[v_ax0, v_ax1, v_ax2])
+ T_rms_norm[v_ax0, v_ax1, v_ax2] = (
+ A[v_ax0, v_ax1, v_ax2]
+ * B[v_ax1, v_ax2]
+ * T.rsqrt(
+ T_multiply_red[v_ax0]
+ / (T.Cast("float32", s) * T.Cast("float32", f))
+ + T.float32(1e-05)
+ )
+ )
+
+ @R.function
+ def main(
+ x: R.Tensor(("n", "s", "f"), dtype="float32"),
+ weight: R.Tensor(("s", "f"), dtype="float32"),
+ ) -> R.Tensor(("n", "s", "f"), dtype="float32"):
+ n = T.int64()
+ s = T.int64()
+ f = T.int64()
+ cls = Expected
+ gv = R.call_tir(
+ cls.rms_norm, (x, weight), out_sinfo=R.Tensor((n, s, f), dtype="float32")
+ )
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(RMSNorm)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_attention():
# fmt: off
@tvm.script.ir_module
diff --git a/tests/python/topi/python/test_topi_rms_norm.py b/tests/python/topi/python/test_topi_rms_norm.py
new file mode 100644
index 0000000000..a30c5bbc97
--- /dev/null
+++ b/tests/python/topi/python/test_topi_rms_norm.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for rms_norm."""
+import numpy as np
+import pytest
+import tvm
+from tvm import te
+from tvm import topi
+from tvm.topi.utils import get_const_tuple
+import tvm.topi.testing
+
+import tvm.testing
+
+
+_rms_norm_schedule = {
+ "generic": topi.generic.schedule_injective,
+}
+
+
+# only test on llvm because schedule is missing
+@tvm.testing.parametrize_targets("llvm")
+@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+def test_rms_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=5e-4, atol=5e-4):
+ data = te.placeholder(shape, dtype=dtype, name="data")
+ scale_shape = [shape[dim] for dim in axis]
+ weight = te.placeholder(scale_shape, dtype=dtype, name="weight")
+ B = topi.nn.rms_norm(data, weight, axis, epsilon)
+
+ data_np = np.random.uniform(size=shape).astype(dtype)
+ weight_np = np.random.uniform(size=scale_shape).astype(dtype)
+ b_np = tvm.topi.testing.rms_norm_python(data_np, weight_np, axis, epsilon)
+
+ with tvm.target.Target(target):
+ s_func = tvm.topi.testing.dispatch(target, _rms_norm_schedule)
+ s = s_func([B])
+ data_tvm = tvm.nd.array(data_np, dev)
+ weight_tvm = tvm.nd.array(weight_np, dev)
+ b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
+ f = tvm.build(s, [data, weight, B], target)
+ f(data_tvm, weight_tvm, b_tvm)
+ tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
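
For completeness, a hedged end-to-end sketch tying the pieces above together
(legalize the Relax op and run it on CPU; the exact relax build/VM API may differ
across unity revisions):

    import numpy as np
    import tvm
    from tvm import relax
    from tvm.relax.transform import LegalizeOps
    from tvm.script import relax as R

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(
            x: R.Tensor((2, 4), "float32"), w: R.Tensor((4,), "float32")
        ) -> R.Tensor((2, 4), "float32"):
            gv: R.Tensor((2, 4), "float32") = R.nn.rms_norm(x, w, axes=[-1])
            return gv

    mod = LegalizeOps()(Mod)  # lowers relax.nn.rms_norm to the TOPI-backed TIR
    ex = relax.build(mod, target="llvm")
    vm = relax.VirtualMachine(ex, tvm.cpu())
    x_np = np.random.rand(2, 4).astype("float32")
    w_np = np.random.rand(4).astype("float32")
    out = vm["main"](tvm.nd.array(x_np), tvm.nd.array(w_np)).numpy()
    ref = x_np * w_np / np.sqrt(np.mean(np.square(x_np), axis=-1, keepdims=True) + 1e-5)
    np.testing.assert_allclose(out, ref, rtol=1e-5, atol=1e-5)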