This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 167bc874e3 [Unity][Training] Loss functions and AppendLoss pass 
(#14668)
167bc874e3 is described below

commit 167bc874e3ed75a72de630905bd1c0f6d16cadc5
Author: Yixin Dong <[email protected]>
AuthorDate: Sat Apr 22 23:34:59 2023 +0800

    [Unity][Training] Loss functions and AppendLoss pass (#14668)
    
    This PR adds two components to support the Relax training workflow: loss 
function library and AppendLoss pass.
    - Loss function library. A library that generates relax functions 
representing common loss functions, such as MSELoss and CrossEntropyLoss.
      - Generally, a loss function takes one or more **input parameters** (that is, the
        outputs of the backbone of a model), one or more **target parameters**, and
        generates a scalar value denoting the loss.
    - AppendLoss pass. You give AppendLoss the name of one function in the IRModule (the
      backbone function) and one loss function (normally generated by the loss function
      library). It appends the loss function to the backbone function and adds the resulting
      function to the IRModule.
    
    Co-authored-by: Chaofan Lin <[email protected]>
---
 python/tvm/relax/training/__init__.py              |   4 +
 .../relax/training/{__init__.py => _ffi_api.py}    |   5 +-
 python/tvm/relax/training/loss.py                  | 292 ++++++++++++++++++
 python/tvm/relax/training/utils.py                 | 155 ++++++++++
 src/relax/training/utils.cc                        | 225 ++++++++++++++
 src/relax/training/utils.h                         |  60 ++++
 tests/python/relax/test_training_append_loss.py    | 327 +++++++++++++++++++++
 tests/python/relax/test_training_loss.py           | 212 +++++++++++++
 8 files changed, 1278 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/training/__init__.py 
b/python/tvm/relax/training/__init__.py
index 2cf602cb4f..3f9ea486e0 100644
--- a/python/tvm/relax/training/__init__.py
+++ b/python/tvm/relax/training/__init__.py
@@ -17,3 +17,7 @@
 """The Relax training APIs."""
 
 from . import optimizer
+from . import utils
+from . import loss
+
+from .utils import AppendLoss
diff --git a/python/tvm/relax/training/__init__.py 
b/python/tvm/relax/training/_ffi_api.py
similarity index 88%
copy from python/tvm/relax/training/__init__.py
copy to python/tvm/relax/training/_ffi_api.py
index 2cf602cb4f..70cb83fc0e 100644
--- a/python/tvm/relax/training/__init__.py
+++ b/python/tvm/relax/training/_ffi_api.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""The Relax training APIs."""
+"""FFI APIs for tvm.relax.training"""
+import tvm._ffi
 
-from . import optimizer
+tvm._ffi._init_api("relax.training", __name__)
diff --git a/python/tvm/relax/training/loss.py 
b/python/tvm/relax/training/loss.py
new file mode 100644
index 0000000000..466c2996e7
--- /dev/null
+++ b/python/tvm/relax/training/loss.py
@@ -0,0 +1,292 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=redefined-builtin, invalid-name
+"""Loss functions library for relax."""
+
+from typing import Optional, Union
+
+# isort: off
+from typing_extensions import Literal
+
+# isort: on
+
+from ..block_builder import BlockBuilder
+from ..expr import Expr, Var, Function, StructInfo
+
+from ..op import abs, sum, mean, subtract, multiply
+from ..op.nn import log_softmax, nll_loss
+
+
+def _create_param_var(param: Union[Var, StructInfo], param_name: str) -> Var:
+    """Create a fresh parameter Var from either a StructInfo or an existing Var.
+
+    If param is a StructInfo, a Var named `param_name` with that StructInfo is created.
+    If param is a Var, a new Var with the same name hint and StructInfo as the given Var is
+    created; the Var passed in is never returned itself.
+
+    Parameters
+    ----------
+    param : Union[Var, StructInfo]
+        The StructInfo of the parameter, or a Var whose name hint and StructInfo are copied.
+
+    param_name : str
+        The name hint used when `param` is a StructInfo. Ignored when `param` is a Var.
+
+    Returns
+    -------
+    var : Var
+        A new Var usable as a function parameter.
+
+    Raises
+    ------
+    TypeError
+        If `param` is neither a Var nor a StructInfo.
+    """
+    if isinstance(param, StructInfo):
+        param = Var(param_name, param)
+    if not isinstance(param, Var):
+        # Format the type explicitly: `str + type` would itself raise a TypeError and hide
+        # this message.
+        raise TypeError(f"The type of param should be Var or StructInfo, but got {type(param)}")
+    return Var(param.name_hint, param.struct_info)
+
+
+class Loss:
+    r"""Base class of all loss.
+
+    Generally, a loss function takes one or more **input parameters** (that is, the outputs of
+    the backbone of a model), one or more **target parameters**, and generates a scalar value
+    denoting the loss.
+
+    You can use `relax.training.AppendLoss` to append the loss function to a one-dataflowblock
+    backbone function in an IRModule. That will generate a one-dataflowblock function accepting
+    instances and targets, and then returning the loss.
+
+    Most loss functions involve a reduction of losses from all instances in a batch. We use the
+    `reduction` parameter to denote the reduction method. Possible reduction methods include
+    `"mean"`, `"sum"` and `"none"`.
+
+    Parameters
+    ----------
+    loss_name : str
+        The name of the loss function. Should be provided when calling `super().__init__` in
+        constructor functions of subclasses.
+
+    num_backbone_outputs : int
+        The number of `prediction_outputs` of the backbone function, also the number of the
+        backbone_prediction_outputs of the loss function. See `relax.training.AppendLoss`.
+
+        Should be provided when calling `super().__init__` in constructor functions of subclasses.
+
+        For example, `CrossEntropyLoss` requires one backbone prediction output; a margin-ranking
+        loss (not provided here) would require two backbone prediction outputs.
+
+    reduction : Literal["mean", "sum", "none"]
+        The reduction method to apply to output. Can be "mean", "sum" or "none".
+
+        none : no reduction will be applied,
+        mean : the mean of the output will be taken,
+        sum : the output will be summed.
+
+    Raises
+    ------
+    ValueError
+        If `reduction` is not one of "mean", "sum" or "none".
+    """
+
+    _valid_reductions = ["mean", "sum", "none"]
+
+    def __init__(
+        self,
+        loss_name: str,
+        num_backbone_outputs: int,
+        reduction: Literal["mean", "sum", "none"] = "mean",
+    ) -> None:
+        self._loss_name = loss_name
+        self._reduction = reduction
+        self._num_backbone_outputs = num_backbone_outputs
+
+        if self._reduction not in self._valid_reductions:
+            # Build one formatted message; passing the list as a second ValueError argument
+            # would render the message as a tuple.
+            raise ValueError(
+                f"Reduction can only be one of these values: {self._valid_reductions}, "
+                f"but got {self._reduction!r}"
+            )
+
+    @property
+    def num_backbone_outputs(self) -> int:
+        """Get the number of the outputs of the backbone function."""
+        return self._num_backbone_outputs
+
+    def _with_reduction(self, expr: Expr) -> Expr:
+        """Add a reduction to the final loss.
+
+        Parameters
+        ----------
+        expr : Expr
+            The loss expr.
+
+        Returns
+        -------
+        ret : Expr
+            The reduced result: `sum(expr)` for "sum", `mean(expr)` for "mean", and `expr`
+            itself for "none".
+
+        Raises
+        ------
+        ValueError
+            If the stored reduction is invalid. `__init__` already rejects invalid values, so
+            this only triggers if `_reduction` was changed after construction.
+        """
+        if self._reduction == "sum":
+            expr = sum(expr)
+        elif self._reduction == "mean":
+            expr = mean(expr)
+        elif self._reduction != "none":
+            raise ValueError(
+                f"Reduction can only be one of these values: {self._valid_reductions}, "
+                f"but got {self._reduction!r}"
+            )
+        return expr
+
+
+class L1Loss(Loss):
+    r"""Element-wise absolute difference between predictions and targets, i.e.
+    `abs(predictions - targets)`, followed by the chosen reduction (the mean absolute error
+    with the default "mean" reduction).
+
+    Parameters
+    ----------
+    reduction : Literal["mean", "sum", "none"]
+        The reduction method to apply to output. Can be "mean", "sum" or "none".
+
+        none : no reduction will be applied,
+        mean : the mean over all elements of the output will be taken,
+        sum : the output will be summed.
+    """
+
+    def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
+        super().__init__("l1_loss", 1, reduction)
+
+    def __call__(
+        self,
+        predictions: Union[Var, StructInfo],
+        targets: Union[Var, StructInfo],
+    ) -> Function:
+        """Get the relax function of L1Loss. If the parameters are
+        struct info, it will create corresponding variables.
+
+        Parameters
+        ----------
+        predictions : Union[Var, StructInfo]
+            The predictions of the model in the calculation of loss.
+        targets : Union[Var, StructInfo]
+            The ground truth in the calculation of loss.
+
+        Returns
+        -------
+        func : Function
+            The relax function of L1Loss, built under the loss name ("l1_loss").
+        """
+        bb = BlockBuilder()
+
+        predictions = _create_param_var(predictions, "predictions")
+        targets = _create_param_var(targets, "targets")
+
+        with bb.function(self._loss_name, [predictions, targets]):
+            with bb.dataflow():
+                lv = abs(subtract(predictions, targets))
+                loss = bb.emit_output(self._with_reduction(lv))
+            bb.emit_func_output(loss)
+
+        return bb.get()[self._loss_name]
+
+
+class MSELoss(Loss):
+    r"""Element-wise squared difference between predictions and targets, i.e.
+    `(predictions - targets) * (predictions - targets)`, followed by the chosen reduction (the
+    mean squared error with the default "mean" reduction).
+
+    Parameters
+    ----------
+    reduction : Literal["mean", "sum", "none"]
+        The reduction method to apply to output. Can be "mean", "sum" or "none".
+
+        none : no reduction will be applied,
+        mean : the mean over all elements of the output will be taken,
+        sum : the output will be summed.
+    """
+
+    def __init__(self, reduction: Literal["mean", "sum", "none"] = "mean") -> None:
+        super().__init__("mse_loss", 1, reduction)
+
+    def __call__(
+        self,
+        predictions: Union[Var, StructInfo],
+        targets: Union[Var, StructInfo],
+    ) -> Function:
+        """Get the relax function of MSELoss. If the parameters are
+        struct info, it will create corresponding variables.
+
+        Parameters
+        ----------
+        predictions : Union[Var, StructInfo]
+            The predictions of the model in the calculation of loss.
+        targets : Union[Var, StructInfo]
+            The ground truth in the calculation of loss.
+
+        Returns
+        -------
+        func : Function
+            The relax function of MSELoss, built under the loss name ("mse_loss").
+        """
+        bb = BlockBuilder()
+
+        predictions = _create_param_var(predictions, "predictions")
+        targets = _create_param_var(targets, "targets")
+
+        with bb.function(self._loss_name, [predictions, targets]):
+            with bb.dataflow():
+                lv = subtract(predictions, targets)
+                lv = multiply(lv, lv)
+                loss = bb.emit_output(self._with_reduction(lv))
+            bb.emit_func_output(loss)
+
+        return bb.get()[self._loss_name]
+
+
+class CrossEntropyLoss(Loss):
+    r"""CrossEntropyLoss. It is a combination of a log_softmax computation and a nll_loss.
+
+    Parameters
+    ----------
+    reduction : Literal["mean", "sum", "none"]
+        The reduction method to apply to output. Can be "mean", "sum" or "none". The reduction
+        is performed by `relax.op.nn.nll_loss` itself.
+
+        none : no reduction will be applied,
+        mean : the output will be averaged (see `nll_loss` for how weights are accounted for),
+        sum : the output will be summed.
+
+    ignore_index : int
+        Specifies a target value that is ignored and does not contribute to the input gradient.
+    """
+
+    ignore_index: int
+
+    def __init__(
+        self,
+        reduction: Literal["mean", "sum", "none"] = "mean",
+        ignore_index: int = -100,
+    ) -> None:
+        super().__init__("cross_entropy_loss", 1, reduction)
+        self.ignore_index = ignore_index
+
+    def __call__(
+        self,
+        predictions: Union[Var, StructInfo],
+        targets: Union[Var, StructInfo],
+        weights: Optional[Union[Var, StructInfo]] = None,
+    ) -> Function:
+        """Get the relax function of CrossEntropyLoss. If the parameters are
+        struct info, it will create corresponding variables.
+
+        Parameters
+        ----------
+        predictions : Union[Var, StructInfo]
+            The predictions of the model in the calculation of loss.
+
+        targets : Union[Var, StructInfo]
+            The ground truth in the calculation of loss.
+
+        weights : Optional[Union[Var, StructInfo]]
+            A manual rescaling weight given to each class. It has to be a Tensor of size C.
+            When given, it becomes a third parameter of the generated function.
+
+        Returns
+        -------
+        func : Function
+            The relax function of CrossEntropyLoss, built under the loss name
+            ("cross_entropy_loss").
+        """
+        bb = BlockBuilder()
+
+        predictions = _create_param_var(predictions, "predictions")
+        targets = _create_param_var(targets, "targets")
+
+        arg_list = [predictions, targets]
+        # Compare with None explicitly: truthiness of a Var/StructInfo object is not a reliable
+        # "was it provided" test.
+        if weights is not None:
+            weights = _create_param_var(weights, "weights")
+            arg_list.append(weights)
+
+        with bb.function(self._loss_name, arg_list):
+            with bb.dataflow():
+                # log_softmax produces log-probabilities (not logits); they are fed to nll_loss.
+                log_probs = bb.emit(log_softmax(predictions))
+                loss = bb.emit_output(
+                    nll_loss(log_probs, targets, weights, self._reduction, self.ignore_index)
+                )
+            bb.emit_func_output(loss)
+
+        return bb.get()[self._loss_name]
diff --git a/python/tvm/relax/training/utils.py 
b/python/tvm/relax/training/utils.py
new file mode 100644
index 0000000000..2944a203a9
--- /dev/null
+++ b/python/tvm/relax/training/utils.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Utility functions for relax training."""
+
+from typing import Optional
+
+import tvm
+from ..expr import Function
+from . import _ffi_api
+
+
+def AppendLoss(
+    func_name: str,
+    loss_function: Function,
+    num_backbone_outputs: int = 1,
+    new_func_name: Optional[str] = None,
+) -> tvm.ir.transform.Pass:
+    """Append the loss function to the backbone function specified by `func_name`. Generally, the
+    loss function is generated by instances of `relax.training.Loss`.
+
+    The backbone function and the loss function should satisfy a few restrictions:
+
+    - Both backbone and loss should contain exactly one DataflowBlock.
+    - Backbone should return either one Var, or a tuple of Vars.
+    - Loss should return a scalar (0-dim Tensor) Var.
+
+    They should be like:
+
+    .. code-block:: python
+
+        @R.function
+        def backbone(input_instances, parameters, states):
+            with R.dataflow():
+                # Predicts the result
+                ...
+            return backbone_result, updated_states
+
+        @R.function
+        def loss(backbone_result, targets):
+            with R.dataflow():
+                # calculate the loss between backbone_result and targets
+                ...
+            # loss should be a scalar Var
+            return loss
+
+    Here each of input_instances, parameters, states, backbone_result and updated_states can
+    denote a number of parameters.
+
+    `states` denote the states that we need to maintain as the training process proceeds, such as
+    the running mean and the running var of the batch norm operator. The updated states are
+    returned in `updated_states`. States can be empty if there is no state that needs to be
+    updated.
+
+    The appended result contains only one DataflowBlock containing all bindings in backbone and
+    loss. It will be like:
+
+    .. code-block:: python
+
+        @R.function
+        def backbone_loss(input_instances, parameters, states, targets):
+            with R.dataflow():
+                # all bindings in backbone and loss
+                ...
+            return loss, updated_states
+
+    Precisely, the first `num_backbone_outputs` return values of the backbone are bound to the
+    first `num_backbone_outputs` parameters of the loss function, and their struct info must be
+    structurally equal. The remaining loss parameters are appended to the backbone parameters,
+    and the remaining backbone return values are returned after the loss.
+
+    Parameters
+    ----------
+    func_name : str
+        The name of the backbone function in the IRModule.
+
+    loss_function : Function
+        The loss function.
+
+    num_backbone_outputs : int
+        Specify the number of `prediction_outputs` of the backbone function. Default: 1.
+
+    new_func_name : Optional[str]
+        Specify the name of the appended result. If it is not specified, the name will be
+        `func_name + "_loss"`.
+
+    Returns
+    -------
+    ret : tvm.ir.transform.Pass
+        A module pass that adds the appended function to the IRModule. The backbone function
+        itself is kept in the module unchanged.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def predict(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32")):
+                with R.dataflow():
+                    out = R.add(x, y)
+                    R.output(out)
+                return out
+
+        @R.function
+        def loss(predictions: R.Tensor((2, 4), "float32"), labels: R.Tensor((2, 4), "float32")):
+            with R.dataflow():
+                lv = R.subtract(predictions, labels)
+                lv1 = R.multiply(lv, lv)
+                gv = R.sum(lv1)
+                R.output(gv)
+            return gv
+
+        expected = AppendLoss("predict", loss)(Module)
+        expected.show()
+
+    Will get
+
+    .. code-block:: python
+
+        @I.ir_module
+        class Module:
+            @R.function
+            def predict(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32")):
+                with R.dataflow():
+                    out = R.add(x, y)
+                    R.output(out)
+                return out
+
+            @R.function
+            def predict_loss(x: R.Tensor((2, 4), "float32"), y: R.Tensor((2, 4), "float32"),
+                             labels: R.Tensor((2, 4), "float32")) -> R.Tensor((), "float32"):
+                with R.dataflow():
+                    out: R.Tensor((2, 4), "float32") = R.add(x, y)
+                    lv: R.Tensor((2, 4), "float32") = R.subtract(out, labels)
+                    lv1: R.Tensor((2, 4), "float32") = R.multiply(lv, lv)
+                    gv: R.Tensor((), "float32") = R.sum(lv1)
+                    R.output(gv)
+                return gv
+
+    Notes
+    -----
+    This util can be replaced if we have an inline pass. It is equivalent to inlining a tail call
+    in some sense.
+    """
+
+    return _ffi_api.AppendLoss(  # type: ignore
+        func_name,
+        loss_function,
+        num_backbone_outputs,
+        new_func_name,
+    )
diff --git a/src/relax/training/utils.cc b/src/relax/training/utils.cc
new file mode 100644
index 0000000000..0a88e4569f
--- /dev/null
+++ b/src/relax/training/utils.cc
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/training/utils.cc
+ * \brief A tool to append the loss function to the backbone function in an 
IRModule.
+ */
+
+#include "utils.h"
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+
+#include <unordered_set>
+
+#include "../transform/utils.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Append the loss function to the backbone function in an IRModule.*/
+class AppendLossMutator : private ExprMutator {
+ public:
+  /*!
+   * \brief Append `loss_function` to the backbone function `func_name` of `mod` and add the
+   * result to the module under `new_func_name` (default: `func_name + "_loss"`).
+   * \param mod The module containing the backbone function.
+   * \param func_name The name of the backbone function.
+   * \param loss_function The loss function to append.
+   * \param num_backbone_outputs The number of leading backbone return values that feed the
+   * leading parameters of the loss function.
+   * \param new_func_name The name of the resulting function.
+   * \return The module with the new function added. The module is obtained through
+   * CopyOnWrite, so other holders of `mod` do not observe the addition.
+   */
+  static IRModule Transform(IRModule mod, String func_name, Function loss_function,
+                            int num_backbone_outputs, Optional<String> new_func_name) {
+    auto* old_func = mod->Lookup(func_name).as<FunctionNode>();
+    CHECK(old_func) << func_name << " is not a Relax Function";
+    // A negative count would make the `begin() + num_backbone_outputs_` iterator arithmetic
+    // below point before the start of the arrays.
+    CHECK_GE(num_backbone_outputs, 0)
+        << "num_backbone_outputs should be non-negative, but got " << num_backbone_outputs;
+
+    // functions should be copied to satisfy the well-formed check
+    Function new_func = CopyWithNewVars(GetRef<Function>(old_func));
+    Function new_loss_func = CopyWithNewVars(loss_function);
+
+    AppendLossMutator mutator(mod, new_loss_func, num_backbone_outputs);
+    auto new_func_transformed = Downcast<Function>(mutator.VisitExpr(new_func));
+
+    String new_name = new_func_name.value_or(func_name + "_loss");
+    // The new function inherits the backbone's attributes. If they carry a global symbol,
+    // point it at the new name so the backbone and the new function do not expose the same
+    // symbol.
+    if (new_func_transformed->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+      new_func_transformed = WithAttr(new_func_transformed, tvm::attr::kGlobalSymbol, new_name);
+    }
+
+    auto new_module = GetRef<IRModule>(mod.CopyOnWrite());
+    auto new_var = GlobalVar(new_name);
+    new_module->Add(new_var, new_func_transformed);
+    return new_module;
+  }
+
+ private:
+  AppendLossMutator(const IRModule& module, const Function& loss_function, int num_backbone_outputs)
+      : ExprMutator(module),
+        loss_function_(loss_function),
+        num_backbone_outputs_(num_backbone_outputs) {}
+
+  /*!
+   * \brief Build the appended function: validate both functions, set up the Var remapping,
+   * extend the parameter list with the loss parameters not fed by the backbone, and rewrite the
+   * body.
+   */
+  Expr VisitExpr_(const FunctionNode* func) final {
+    CHECK(func->body->IsInstance<SeqExprNode>() && loss_function_->body->IsInstance<SeqExprNode>())
+        << "The bodies of the backbone and the loss function must be SeqExpr.";
+
+    // Well-formed checks and setting up class members. The order matters:
+    // CheckAndRemapBackboneReturn may replace entries of backbone_return_arr_ with DataflowVars,
+    // and CheckAndRemapLossParams then maps the loss parameters onto those entries.
+    loss_body_ = Downcast<SeqExpr>(loss_function_->body);
+    CheckLossBody();
+    BackboneReturnToArr(func->body.as<SeqExprNode>()->body);
+    CheckAndRemapBackboneReturn();
+    CheckAndRemapLossParams(loss_function_->params);
+
+    // New parameters: all backbone parameters, followed by the loss parameters after the first
+    // num_backbone_outputs_ ones (those are bound to backbone outputs instead).
+    Array<Var> new_params = func->params;
+    new_params.insert(new_params.end(), loss_function_->params.begin() + num_backbone_outputs_,
+                      loss_function_->params.end());
+    Expr new_body = this->VisitExpr(func->body);
+
+    // The return struct info is not given explicitly (NullOpt).
+    return Function(new_params, new_body, NullOpt, func->attrs);
+  }
+
+  /*!
+   * \brief Rewrite the backbone body into one DataflowBlock and return the loss followed by the
+   * backbone return values after the first num_backbone_outputs_.
+   */
+  Expr VisitExpr_(const SeqExprNode* seq_expr) final {
+    CHECK(seq_expr->blocks.size() == 1 && seq_expr->blocks[0]->IsInstance<DataflowBlockNode>())
+        << "Backbone should have only one DataflowBlock";
+
+    auto new_blocks = Array<BindingBlock>({this->VisitBindingBlock(seq_expr->blocks[0])});
+    auto ret = Array<Expr>({loss_body_->body});
+    ret.insert(ret.end(), backbone_return_arr_.begin() + num_backbone_outputs_,
+               backbone_return_arr_.end());
+    // A single return value is returned directly rather than as a 1-tuple.
+    return SeqExpr(new_blocks, ret.size() == 1 ? ret[0] : Tuple(ret));
+  }
+
+  /*! \brief Emit the backbone bindings followed by the loss bindings into one DataflowBlock. */
+  BindingBlock VisitBindingBlock_(const DataflowBlockNode* block) final {
+    builder_->BeginDataflowBlock();
+    // Emit original bindings.
+    for (const auto& binding : block->bindings) {
+      this->VisitBinding(binding);
+    }
+
+    // Emit bindings in the loss function. Uses of the loss parameters bound to backbone outputs
+    // are substituted through var_remap_.
+    for (const Binding& binding : loss_body_->blocks[0]->bindings) {
+      this->VisitBinding(binding);
+    }
+
+    return builder_->EndBlock();
+  }
+
+  /*!
+   * \brief Using VisitExpr to remap the defined variable. This is different from the standard
+   * behaviour of VisitVarDef: it lets backbone return Vars recorded in var_remap_ (turned into
+   * DataflowVars) be defined under their new form at their binding site.
+   */
+  Var VisitVarDef(const Var& var) final { return Downcast<Var>(this->VisitExpr(var)); }
+
+  /*! \brief Checks the loss function have only one DataflowBlock, and returns a scalar Var. */
+  void CheckLossBody() {
+    CHECK(loss_body_->blocks.size() == 1 && loss_body_->blocks[0]->IsInstance<DataflowBlockNode>())
+        << "The loss function should have only one DataflowBlock";
+    auto var_node = loss_body_->body.as<VarNode>();
+    CHECK(var_node && IsScalarTensor(GetRef<Var>(var_node)))
+        << "The loss function must return a scalar(0-dim Tensor) Var";
+  }
+
+  /*!
+   * \brief Convert the return value of the backbone to Array<Var>. The backbone should return one
+   * or a tuple of Vars.
+   */
+  void BackboneReturnToArr(const Expr& backbone_return) {
+    if (auto* var = backbone_return.as<VarNode>()) {
+      backbone_return_arr_.push_back(GetRef<Var>(var));
+    } else if (auto* tuple = backbone_return.as<TupleNode>()) {
+      for (auto i : tuple->fields) {
+        auto var = i.as<VarNode>();
+        CHECK(var) << "The return value of the backbone should be either a Var or a Tuple of Vars";
+        backbone_return_arr_.push_back(GetRef<Var>(var));
+      }
+    } else {
+      LOG(FATAL) << "The return value of the backbone should be either a Var or a Tuple of Vars";
+    }
+  }
+
+  /*!
+   * \brief Check the number of elements in loss_func_params is no less than num_backbone_outputs,
+   * and the elements in backbone_return_arr_ and loss_func_params have matched struct_info. Also
+   * sets up var_remap_ from loss parameter Vars to backbone returned Vars.
+   */
+  void CheckAndRemapLossParams(const Array<Var>& loss_func_params) {
+    static StructuralEqual checker;
+    CHECK(static_cast<int>(loss_func_params.size()) >= num_backbone_outputs_)
+        << "The number of parameters of the loss function is " << loss_func_params.size()
+        << ", which is less than the given num_backbone_outputs " << num_backbone_outputs_;
+    for (int i = 0; i < num_backbone_outputs_; ++i) {
+      Var loss_param = loss_func_params[i];
+      Var backbone_ret = backbone_return_arr_[i];
+      auto loss_param_sinfo = GetStructInfo(loss_param);
+      auto backbone_ret_sinfo = GetStructInfo(backbone_ret);
+
+      CHECK(checker(backbone_ret_sinfo, loss_param_sinfo))
+          << "The struct info of the " << i
+          << "-th return value of backbone function is: " << backbone_ret_sinfo
+          << " while the corresponding struct info of parameter of loss function is "
+          << loss_param_sinfo << ", which is different.";
+
+      this->var_remap_[loss_param->vid] = backbone_ret;
+    }
+  }
+
+  /*!
+   * \brief Check the number of elements in backbone_return_arr_ is no less than
+   * num_backbone_outputs. Then remap Vars in backbone return values that satisfy these conditions
+   * from Var to DataflowVar:
+   *
+   * 1. Is used in prediction_outputs of the backbone function,
+   * 2. Is not used in other_outputs of the backbone function.
+   *
+   * Because such Vars are no longer the outputs of the new function.
+   */
+  void CheckAndRemapBackboneReturn() {
+    CHECK(static_cast<int>(backbone_return_arr_.size()) >= num_backbone_outputs_)
+        << "The number of return values of the backbone function is " << backbone_return_arr_.size()
+        << ", which is less than the given num_backbone_outputs " << num_backbone_outputs_;
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> other_outputs_var(
+        backbone_return_arr_.begin() + num_backbone_outputs_, backbone_return_arr_.end());
+    for (int i = 0; i < num_backbone_outputs_; ++i) {
+      auto var = backbone_return_arr_[i];
+      if (other_outputs_var.count(var) == 0) {
+        auto new_var = DataflowVar(var->vid, GetStructInfo(var), var->span);
+        this->var_remap_[var->vid] = new_var;
+        backbone_return_arr_.Set(i, new_var);
+      }
+    }
+  }
+
+  /*! \brief The loss function. */
+  Function loss_function_;
+  /*! \brief The number of prediction_outputs of the backbone function. */
+  int num_backbone_outputs_;
+  /*! \brief The body of the loss function */
+  SeqExpr loss_body_;
+  /*! \brief The unpacked return values of the backbone. All return values should be Vars. */
+  Array<Var> backbone_return_arr_;
+};
+
+namespace transform {
+
+Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outputs,
+                Optional<String> new_func_name) {
+  // All arguments are captured by value so the returned pass does not depend on this frame.
+  auto append_loss = [=](IRModule mod, PassContext ctx) -> IRModule {
+    return relax::AppendLossMutator::Transform(mod, func_name, loss_function,
+                                               num_backbone_outputs, new_func_name);
+  };
+  runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func(append_loss);
+
+  constexpr int kOptLevel = 0;
+  return CreateModulePass(/*pass_function=*/pass_func, /*opt_level=*/kOptLevel,
+                          /*pass_name=*/"AppendLoss", /*required=*/{});
+}
+
+TVM_REGISTER_GLOBAL("relax.training.AppendLoss").set_body_typed(AppendLoss);
+
+}  // namespace transform
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/training/utils.h b/src/relax/training/utils.h
new file mode 100644
index 0000000000..074aedc287
--- /dev/null
+++ b/src/relax/training/utils.h
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/training/utils.h
+ * \brief Utility classes and functions for relax training.
+ */
+#ifndef TVM_RELAX_TRAINING_UTILS_H_
+#define TVM_RELAX_TRAINING_UTILS_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+namespace transform {
+
+/*!
+ * \brief Append the loss function to the backbone function specified by `func_name`.
+ * Generally, the loss function is generated by instances of `relax.training.Loss`.
+ *
+ * The backbone function and the loss function should satisfy a few restrictions:
+ * - Both backbone and loss should contain exactly one DataflowBlock.
+ * - Backbone should return either one Var, or a tuple of Vars.
+ * - Loss should return a scalar (0-dim Tensor) Var.
+ *
+ * The appended result contains only one DataflowBlock containing all bindings in backbone and
+ * loss. The first `num_backbone_outputs` backbone return values feed the first
+ * `num_backbone_outputs` loss parameters (their struct info must be structurally equal). The
+ * remaining loss parameters are appended to the backbone parameters. The result returns the
+ * loss followed by the remaining backbone return values. The backbone function itself is kept.
+ *
+ * \param func_name The name of the backbone function in the IRModule.
+ * \param loss_function The loss function.
+ * \param num_backbone_outputs Specify the number of `prediction_outputs` of the backbone
+ * function. Default: 1.
+ * \param new_func_name Specify the name of the appended result. If it is not specified, the
+ * name will be `func_name + "_loss"`.
+ * \return A module pass that adds the appended function to the module.
+ */
+TVM_DLL Pass AppendLoss(String func_name, Function loss_function, int num_backbone_outputs = 1,
+                        Optional<String> new_func_name = NullOpt);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_TRAINING_UTILS_H_
diff --git a/tests/python/relax/test_training_append_loss.py 
b/tests/python/relax/test_training_append_loss.py
new file mode 100644
index 0000000000..be6a2144cf
--- /dev/null
+++ b/tests/python/relax/test_training_append_loss.py
@@ -0,0 +1,327 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm.testing
+from tvm import TVMError
+from tvm.ir.base import assert_structural_equal
+from tvm.script import relax as R, ir as I
+from tvm.relax.training import AppendLoss
+
+
+def test_simple():
+    """AppendLoss keeps `main` unchanged and adds `main_loss`, whose single dataflow block runs
+    the bindings of `main` followed by those of `loss` and returns the scalar loss."""
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                gv0 = x + y
+                R.output(gv0)
+            return gv0
+
+    @R.function
+    def loss(arg1: R.Tensor((3, 3), "float32")):
+        with R.dataflow():
+            gv0 = R.sum(arg1)
+            R.output(gv0)
+        return gv0
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tensor((3, 3), "float32"):
+            with R.dataflow():
+                gv0: R.Tensor((3, 3), "float32") = R.add(x, y)
+                R.output(gv0)
+            return gv0
+
+        @R.function
+        def main_loss(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                gv0: R.Tensor((3, 3), "float32") = R.add(x, y)
+                gv0_1: R.Tensor((), "float32") = R.sum(gv0, axis=None, 
keepdims=False)
+                R.output(gv0_1)
+            return gv0_1
+    # fmt: on
+
+    After = AppendLoss("main", loss)(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_num_backbone_outputs():
+    """With num_backbone_outputs=2, both elements of the tuple returned by `main` are passed, in
+    order, to the two parameters of `loss`."""
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                gv0 = R.sum(x)
+                gv1 = R.sum(y)
+                R.output(gv0, gv1)
+            return gv0, gv1
+
+    @R.function
+    def loss(arg1: R.Tensor((), "float32"), arg2: R.Tensor((), "float32")):
+        with R.dataflow():
+            gv0 = R.add(arg1, arg2)
+            R.output(gv0)
+        return gv0
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tensor((), "float32")):
+            with R.dataflow():
+                gv0: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                gv1: R.Tensor((), "float32") = R.sum(y, axis=None, 
keepdims=False)
+                R.output(gv0, gv1)
+            return (gv0, gv1)
+
+        @R.function
+        def main_loss(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")) -> R.Tensor((), "float32"):
+            with R.dataflow():
+                gv0: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                gv1: R.Tensor((), "float32") = R.sum(y, axis=None, 
keepdims=False)
+                gv0_1: R.Tensor((), "float32") = R.add(gv0, gv1)
+                R.output(gv0_1)
+            return gv0_1
+    # fmt: on
+
+    After = AppendLoss("main", loss, 2)(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_extra_params():
+    """Backbone outputs beyond num_backbone_outputs (`gv2`) are returned after the loss value,
+    and loss parameters beyond them (`arg3`) become extra parameters of `main_loss`. Backbone
+    bindings are kept even when the loss does not use them (`gv0`)."""
+    # fmt: off
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv0 = R.sum(x)
+                gv1 = R.add(x, x)
+                gv2 = x
+                R.output(gv0, gv1, gv2)
+            return gv0, gv1, gv2
+
+    @R.function
+    def loss(
+        arg1: R.Tensor((), "float32"),
+        arg2: R.Tensor((3, 3), "float32"),
+        arg3: R.Tensor((3, 3), "float32"),
+    ):
+        with R.dataflow():
+            gv0 = R.add(arg2, arg3)
+            gv1 = R.sum(gv0)
+            R.output(gv1)
+        return gv1
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")) -> R.Tuple(R.Tensor((), 
"float32"), R.Tensor((3, 3), "float32"), R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv0: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                gv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+                gv2: R.Tensor((3, 3), "float32") = x
+                R.output(gv0, gv1, gv2)
+            return (gv0, gv1, gv2)
+
+        @R.function
+        def main_loss(x: R.Tensor((3, 3), "float32"), arg3: R.Tensor((3, 3), 
"float32")) -> R.Tuple(R.Tensor((), "float32"), R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv0: R.Tensor((), "float32") = R.sum(x, axis=None, 
keepdims=False)
+                gv1: R.Tensor((3, 3), "float32") = R.add(x, x)
+                gv2: R.Tensor((3, 3), "float32") = x
+                gv0_1: R.Tensor((3, 3), "float32") = R.add(gv1, arg3)
+                gv1_1: R.Tensor((), "float32") = R.sum(gv0_1, axis=None, 
keepdims=False)
+                R.output(gv2, gv1_1)
+            return (gv1_1, gv2)
+    # fmt: on
+
+    After = AppendLoss("main", loss, 2)(Before)
+    assert_structural_equal(After, Expected)
+
+
+def test_error_return_value_vs_parameter():
+    """AppendLoss raises TVMError when the backbone outputs cannot be bound to the loss
+    parameters."""
+    # The StructInfo of the backbone outputs (float32) does not match the loss parameters (float64)
+    # fmt: off
+    @I.ir_module
+    class Module1:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                gv0 = R.sum(x)
+                gv1 = R.sum(y)
+                R.output(gv0, gv1)
+            return gv0, gv1
+
+    @R.function
+    def loss1(arg1: R.Tensor((), "float64"), arg2: R.Tensor((), "float64")):
+        with R.dataflow():
+            gv0 = R.add(arg1, arg2)
+            R.output(gv0)
+        return gv0
+    # fmt: on
+
+    with pytest.raises(TVMError):
+        AppendLoss("main", loss1, 2)(Module1)
+
+    # num_backbone_outputs=2 exceeds both the number of backbone return values and the number
+    # of loss parameters (one each)
+    # fmt: off
+    @I.ir_module
+    class Module2:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                gv0 = x + y
+                R.output(gv0)
+            return gv0
+
+    @R.function
+    def loss2(arg1: R.Tensor((3, 3), "float32")):
+        with R.dataflow():
+            gv0 = R.sum(arg1)
+            R.output(gv0)
+        return gv0
+    # fmt: on
+
+    with pytest.raises(TVMError):
+        AppendLoss("main", loss2, 2)(Module2)
+
+    # The backbone returns a nested tuple, which is not supported
+    # fmt: off
+    @I.ir_module
+    class Module3:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32"), y: R.Tensor((3, 3), 
"float32")):
+            with R.dataflow():
+                gv0 = x
+                gv1 = y
+                gv2 = x + y
+                R.output(gv0, gv1, gv2)
+            return gv0, (gv1, gv2)
+
+    @R.function
+    def loss3(arg1: R.Tensor((3, 3), "float32")):
+        with R.dataflow():
+            gv0 = R.sum(arg1)
+            R.output(gv0)
+        return gv0
+    # fmt: on
+
+    with pytest.raises(TVMError):
+        AppendLoss("main", loss3, 1)(Module3)
+
+
+def test_error_more_blocks():
+    """AppendLoss raises TVMError when the backbone or the loss has more than one block."""
+    # The backbone has a binding outside its dataflow block, so it has more than one block
+    # fmt: off
+    @I.ir_module
+    class Module1:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv0 = x
+                R.output(gv0)
+            gv1 = gv0
+            return gv1
+
+    @R.function
+    def loss1(arg: R.Tensor((3, 3), "float32")):
+        with R.dataflow():
+            gv = R.sum(arg)
+            R.output(gv)
+        return gv
+    # fmt: on
+
+    with pytest.raises(TVMError):
+        AppendLoss("main", loss1)(Module1)
+
+    # The loss has a binding outside its dataflow block, so it has more than one block
+    # fmt: off
+    @I.ir_module
+    class Module2:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv0 = x
+                R.output(gv0)
+            return gv0
+
+    @R.function
+    def loss2(arg: R.Tensor((3, 3), "float32")):
+        with R.dataflow():
+            gv = R.sum(arg)
+            R.output(gv)
+        gv1 = gv
+        return gv1
+    # fmt: on
+
+    with pytest.raises(TVMError):
+        AppendLoss("main", loss2)(Module2)
+
+
+def test_loss_return_value():
+    """AppendLoss raises TVMError when the loss does not return a single scalar Var."""
+    # The loss returns a non-scalar Var (a (3, 3) tensor)
+    # fmt: off
+    @I.ir_module
+    class Module1:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv0 = x
+                R.output(gv0)
+            return gv0
+
+    @R.function
+    def loss1(arg1: R.Tensor((3, 3), "float32")):
+        with R.dataflow():
+            gv0 = arg1
+            R.output(gv0)
+        return gv0
+    # fmt: on
+
+    with pytest.raises(TVMError):
+        AppendLoss("main", loss1)(Module1)
+
+    # The loss returns a tuple instead of a single Var
+    # fmt: off
+    @I.ir_module
+    class Module2:
+        @R.function
+        def main(x: R.Tensor((3, 3), "float32")):
+            with R.dataflow():
+                gv0 = x
+                R.output(gv0)
+            return gv0
+
+    @R.function
+    def loss2(arg1: R.Tensor((3, 3), "float32")):
+        with R.dataflow():
+            gv0 = R.sum(arg1)
+            gv1 = gv0 + gv0
+            R.output(gv0, gv1)
+        return gv0, gv1
+    # fmt: on
+
+    with pytest.raises(TVMError):
+        AppendLoss("main", loss2)(Module2)
+
+
+# Allow running this test file directly as a script.
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_training_loss.py 
b/tests/python/relax/test_training_loss.py
new file mode 100644
index 0000000000..68d59dca05
--- /dev/null
+++ b/tests/python/relax/test_training_loss.py
@@ -0,0 +1,212 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm.testing
+from tvm import relax
+from tvm.ir.base import assert_structural_equal
+from tvm.script import relax as R, ir as I
+
+
+# Backbone shared by the *_append tests below: out = matmul(x, w) + b.
+@I.ir_module
+class Module:
+    @R.function
+    def forward(
+        x: R.Tensor((2, 4), "float32"),
+        w: R.Tensor((4, 4), "float32"),
+        b: R.Tensor((2, 4), "float32"),
+    ) -> R.Tensor((2, 4), "float32"):
+        with R.dataflow():
+            lv: R.Tensor((2, 4), "float32") = R.matmul(x, w)
+            out: R.Tensor((2, 4), "float32") = R.add(lv, b)
+            R.output(out)
+        return out
+
+
+def test_l1_loss():
+    """L1Loss() with the default reduction builds mean(abs(predictions - targets))."""
+    N = 3
+    C = 5
+    predictions = relax.TensorStructInfo((N, C), "float32")
+    targets = relax.TensorStructInfo((N, C), "float32")
+    l1_loss = relax.training.loss.L1Loss()
+
+    @R.function
+    def expected(
+        predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), 
"float32")
+    ) -> R.Tensor((), "float32"):
+        with R.dataflow():
+            lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets)
+            lv1: R.Tensor((3, 5), "float32") = R.abs(lv)
+            gv: R.Tensor((), "float32") = R.mean(lv1, axis=None, 
keepdims=False)
+            R.output(gv)
+        return gv
+
+    assert_structural_equal(l1_loss(predictions, targets), expected)
+
+
+def test_l1_loss_append():
+    """Appending L1Loss(reduction="sum") to `forward` creates `forward_loss`, which takes the
+    loss's `targets` as an extra parameter and sums abs(out - targets)."""
+    s = Module["forward"].ret_struct_info
+    l1_loss = relax.training.loss.L1Loss(reduction="sum")
+    After = relax.training.AppendLoss("forward", l1_loss(s, s), 
l1_loss.num_backbone_outputs)(
+        Module
+    )
+
+    @R.function
+    def expected(
+        x: R.Tensor((2, 4), "float32"),
+        w: R.Tensor((4, 4), "float32"),
+        b: R.Tensor((2, 4), "float32"),
+        targets: R.Tensor((2, 4), "float32"),
+    ) -> R.Tensor((), "float32"):
+        with R.dataflow():
+            lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
+            out: R.Tensor((2, 4), "float32") = R.add(lv, b)
+            lv1: R.Tensor((2, 4), "float32") = R.subtract(out, targets)
+            lv11: R.Tensor((2, 4), "float32") = R.abs(lv1)
+            gv: R.Tensor((), "float32") = R.sum(lv11, axis=None, 
keepdims=False)
+            R.output(gv)
+        return gv
+
+    assert_structural_equal(After["forward_loss"], expected)
+
+
+def test_mse_loss():
+    """MSELoss() with the default reduction builds mean((predictions - targets) ** 2)."""
+    N = 3
+    C = 5
+    predictions = relax.TensorStructInfo((N, C), "float32")
+    targets = relax.TensorStructInfo((N, C), "float32")
+    mse_loss = relax.training.loss.MSELoss()
+
+    @R.function
+    def expected(
+        predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3, 5), 
"float32")
+    ) -> R.Tensor((), "float32"):
+        with R.dataflow():
+            lv: R.Tensor((3, 5), "float32") = R.subtract(predictions, targets)
+            lv1: R.Tensor((3, 5), "float32") = R.multiply(lv, lv)
+            gv: R.Tensor((), "float32") = R.mean(lv1, axis=None, 
keepdims=False)
+            R.output(gv)
+        return gv
+
+    assert_structural_equal(mse_loss(predictions, targets), expected)
+
+
+def test_mse_loss_append():
+    """Appending MSELoss(reduction="sum") to `forward` creates `forward_loss`, which takes the
+    loss's `targets` as an extra parameter and sums (out - targets) ** 2."""
+    s = Module["forward"].ret_struct_info
+    mse_loss = relax.training.loss.MSELoss(reduction="sum")
+    After = relax.training.AppendLoss("forward", mse_loss(s, s), 
mse_loss.num_backbone_outputs)(
+        Module
+    )
+
+    @R.function
+    def expected(
+        x: R.Tensor((2, 4), "float32"),
+        w: R.Tensor((4, 4), "float32"),
+        b: R.Tensor((2, 4), "float32"),
+        targets: R.Tensor((2, 4), "float32"),
+    ) -> R.Tensor((), "float32"):
+        with R.dataflow():
+            lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
+            out: R.Tensor((2, 4), "float32") = R.add(lv, b)
+            lv1: R.Tensor((2, 4), "float32") = R.subtract(out, targets)
+            lv11: R.Tensor((2, 4), "float32") = R.multiply(lv1, lv1)
+            gv: R.Tensor((), "float32") = R.sum(lv11, axis=None, 
keepdims=False)
+            R.output(gv)
+        return gv
+
+    assert_structural_equal(After["forward_loss"], expected)
+
+
+def test_cross_entropy_loss():
+    """CrossEntropyLoss lowers to log_softmax over the last axis followed by nll_loss, which
+    receives the weights together with the configured reduction and ignore_index."""
+    N = 3
+    C = 5
+    predictions = relax.TensorStructInfo((N, C), "float32")
+    targets = relax.TensorStructInfo((N,), "int64")
+    weights = relax.TensorStructInfo((C,), "float32")
+    cross_entropy_loss = relax.training.loss.CrossEntropyLoss(reduction="sum", 
ignore_index=1)
+
+    @R.function
+    def expected(
+        predictions: R.Tensor((3, 5), "float32"),
+        targets: R.Tensor((3,), "int64"),
+        weights: R.Tensor((5,), "float32"),
+    ) -> R.Tensor((), "float32"):
+        with R.dataflow():
+            lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, 
axis=-1)
+            gv: R.Tensor((), "float32") = R.nn.nll_loss(
+                lv, targets, weights, reduction="sum", ignore_index=1
+            )
+            R.output(gv)
+        return gv
+
+    assert_structural_equal(cross_entropy_loss(predictions, targets, weights), 
expected)
+
+
+def test_cross_entropy_loss_without_weights():
+    """Without weights, CrossEntropyLoss() calls nll_loss with no weights argument and the
+    defaults reduction="mean" and ignore_index=-100."""
+    N = 3
+    C = 5
+    predictions = relax.TensorStructInfo((N, C), "float32")
+    targets = relax.TensorStructInfo((N,), "int64")
+    cross_entropy_loss = relax.training.loss.CrossEntropyLoss()
+
+    @R.function
+    def expected(
+        predictions: R.Tensor((3, 5), "float32"), targets: R.Tensor((3,), 
"int64")
+    ) -> R.Tensor((), "float32"):
+        with R.dataflow():
+            lv: R.Tensor((3, 5), "float32") = R.nn.log_softmax(predictions, 
axis=-1)
+            gv: R.Tensor((), "float32") = R.nn.nll_loss(
+                lv, targets, reduction="mean", ignore_index=-100
+            )
+            R.output(gv)
+        return gv
+
+    assert_structural_equal(cross_entropy_loss(predictions, targets), expected)
+
+
+def test_cross_entropy_loss_append():
+    """The targets and weights parameters of CrossEntropyLoss become extra parameters of
+    `forward_loss`; their shapes are derived from the backbone output shape (N, C) = (2, 4)."""
+    s = Module["forward"].ret_struct_info
+    N = s.shape[0]
+    C = s.shape[1]
+    targets = relax.TensorStructInfo((N,), "int64")
+    weights = relax.TensorStructInfo((C,), "float32")
+    cross_entropy_loss = relax.training.loss.CrossEntropyLoss(reduction="sum", 
ignore_index=1)
+    After = relax.training.AppendLoss(
+        "forward", cross_entropy_loss(s, targets, weights), 
cross_entropy_loss.num_backbone_outputs
+    )(Module)
+
+    @R.function
+    def expected(
+        x: R.Tensor((2, 4), "float32"),
+        w: R.Tensor((4, 4), "float32"),
+        b: R.Tensor((2, 4), "float32"),
+        targets: R.Tensor((2,), "int64"),
+        weights: R.Tensor((4,), "float32"),
+    ) -> R.Tensor((), "float32"):
+        with R.dataflow():
+            lv: R.Tensor((2, 4), "float32") = R.matmul(x, w, out_dtype="")
+            out: R.Tensor((2, 4), "float32") = R.add(lv, b)
+            lv1: R.Tensor((2, 4), "float32") = R.nn.log_softmax(out, axis=-1)
+            gv: R.Tensor((), "float32") = R.nn.nll_loss(
+                lv1, targets, weights, reduction="sum", ignore_index=1
+            )
+            R.output(gv)
+        return gv
+
+    assert_structural_equal(After["forward_loss"], expected)
+
+
+# Allow running this test file directly as a script.
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to