This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 7d1a101e24 [Unity][Frontend][NN] Add Timesteps layer to NN Module API 
(#15603)
7d1a101e24 is described below

commit 7d1a101e242bd1c0bb6ac2151528ef998f94c5ad
Author: Josh Fromm <[email protected]>
AuthorDate: Thu Aug 24 23:35:58 2023 -0700

    [Unity][Frontend][NN] Add Timesteps layer to NN Module API (#15603)
    
    * Add support for relax.nn.pad
    
    * Formatting
    
    * Add nn.pad testing
    
    * Add support for Timesteps layer
    
    * Simplify arithmetic
    
    * Format
    
    * Bump ci
    
    * Format
    
    * Add TimestepEmbedding layer
    
    * Add TimestepEmbedding tests
    
    * Fix cast style
    
    * Add missing docstring
---
 include/tvm/relax/attrs/nn.h                       |  20 +++-
 python/tvm/relax/frontend/nn/modules.py            |  97 ++++++++++++++++++++
 python/tvm/relax/frontend/nn/op.py                 |  67 +++++++++++++-
 python/tvm/relax/op/nn/nn.py                       |  31 ++++++-
 python/tvm/relax/transform/legalize_ops/nn.py      |  16 ++++
 src/relax/op/nn/nn.cc                              |  46 ++++++++++
 tests/python/relax/test_frontend_nn_modules.py     | 102 +++++++++++++++++++++
 tests/python/relax/test_frontend_nn_op.py          |  39 ++++++++
 tests/python/relax/test_op_nn.py                   |  26 ++++++
 .../python/relax/test_transform_legalize_ops_nn.py |  42 +++++++++
 10 files changed, 480 insertions(+), 6 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 2dc610f654..0d895dccb1 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -375,7 +375,7 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
   }
 };  // struct DropoutAttrs
 
-/*! \brief Attributes used in dropout operator */
+/*! \brief Attributes used in Attention operator */
 struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
   Optional<FloatImm> scale;
   Optional<String> causal_mask;
@@ -388,6 +388,24 @@ struct AttentionAttrs : public 
tvm::AttrsNode<AttentionAttrs> {
   }
 };  // struct AttentionAttrs
 
+/*!
+ * \brief Attributes used for the padding operator.
+ *
+ * The fill value is not stored here; it is passed to the operator as its
+ * second call argument.
+ */
+struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
+  /*! \brief Flattened per-axis widths: (before_1, after_1, ..., before_N, after_N). */
+  Array<Integer> pad_width;
+  /*! \brief Padding mode; defaults to "constant". */
+  tvm::String pad_mode;
+
+  // The type key lives in the relax namespace so that it does not collide with
+  // the relay attribute node of the same name.
+  TVM_DECLARE_ATTRS(PadAttrs, "relax.attrs.PadAttrs") {
+    TVM_ATTR_FIELD(pad_width).describe(
+        "Number of values padded to the edges of each axis, "
+        "in the format of (before_1, after_1, ..., before_N, after_N)");
+    TVM_ATTR_FIELD(pad_mode)
+        .set_default("constant")
+        .describe(
+            "Padding type to use. \"constant\" pads with constant_value, "
+            "\"edge\" pads using the edge values of the input array, "
+            "\"reflect\" pads by reflecting values with respect to the edges.");
+  }
+};  // struct PadAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/python/tvm/relax/frontend/nn/modules.py 
b/python/tvm/relax/frontend/nn/modules.py
index cd94c23115..fde18473ee 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -69,6 +69,15 @@ def _print(_, array: NDArray) -> None:
     print(f"effect.print: shape = {array.shape}, dtype = {array.dtype}, data 
=\n{array}")
 
 
+class SiLU(Module):
+    """
+    Module for SiLU activation layer.
+    """
+
+    def forward(self, x: Tensor):
+        """
+        Forward method for SiLU layer.
+
+        Parameters
+        ----------
+        x : Tensor
+            The input tensor.
+
+        Returns
+        -------
+        ret : Tensor
+            The result of `op.silu` applied to `x`.
+        """
+        return op.silu(x)
+
+
 class Linear(Module):
     """
     Module for linear layer.
@@ -524,3 +533,91 @@ class Embedding(Module):
             ),
             shape=[*x.shape, self.dim],  # TODO(@junrushao): revisit and 
remove self.dim
         )
+
+
+class TimestepEmbedding(Module):
+    """
+    Module for HF TimestepEmbedding layer.
+
+    Computes ``linear_2(act(linear_1(sample)))``. A bias-free linear projection
+    of a conditioning tensor can be added to ``sample`` beforehand, and a
+    post-activation can be applied to the result.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: Optional[int] = None,
+        post_act_fn: Optional[str] = None,
+        cond_proj_dim: Optional[int] = None,
+    ):
+        """
+        Construct the layers of the embedding.
+
+        Parameters
+        ----------
+        in_channels : int
+            Input feature size of the first linear layer.
+        time_embed_dim : int
+            Output feature size of the first linear layer.
+        act_fn : str
+            Activation between the two linear layers. Only "silu" is accepted.
+        out_dim : Optional[int]
+            Output feature size of the second linear layer. Falls back to
+            `time_embed_dim` when None.
+        post_act_fn : Optional[str]
+            Activation applied after the second linear layer. Only "silu" or
+            None is accepted.
+        cond_proj_dim : Optional[int]
+            Feature size of the conditioning input. When given, a bias-free
+            linear layer projecting it to `in_channels` is created.
+        """
+        self.linear_1 = Linear(in_channels, time_embed_dim)
+
+        if cond_proj_dim is not None:
+            self.cond_proj = Linear(cond_proj_dim, in_channels, bias=False)
+        else:
+            self.cond_proj = None
+
+        assert act_fn == "silu", "Only SiLU activations are supported."
+        self.act = SiLU()
+
+        time_embed_dim_out = time_embed_dim if out_dim is None else out_dim
+        self.linear_2 = Linear(time_embed_dim, time_embed_dim_out)
+
+        if post_act_fn is None:
+            self.post_act = None
+        else:
+            # Validate the constructor argument: self.post_act has not been
+            # assigned at this point, so it cannot be the thing to check.
+            assert post_act_fn == "silu", "Only SiLU post-activation supported."
+            self.post_act = SiLU()
+
+    def forward(self, sample: Tensor, condition: Optional[Tensor] = None):
+        """
+        Forward method for TimestepEmbedding layer.
+
+        Parameters
+        ----------
+        sample : Tensor
+            The input timestep that should be looked up.
+        condition : Optional[Tensor]
+            Optional additional projection matrix.
+
+        Returns
+        -------
+        ret : Tensor
+            The resulting embedding lookup for the input sample.
+
+        Raises
+        ------
+        ValueError
+            If `condition` is given but the module was built without
+            `cond_proj_dim`, so there is no projection layer to apply.
+        """
+        if condition is not None:
+            if self.cond_proj is None:
+                raise ValueError(
+                    "`condition` was provided but the module was created without `cond_proj_dim`."
+                )
+            sample = sample + self.cond_proj(condition)
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+
+        if self.post_act is not None:
+            sample = self.post_act(sample)
+        return sample
+
+
+class Timesteps(Module):
+    """
+    Module for HF timesteps layer.
+
+    Thin wrapper that stores its configuration and forwards it to
+    `op.get_timestep_embedding`.
+    """
+
+    def __init__(
+        self,
+        num_channels: int,
+        flip_sin_to_cos: bool = False,
+        downscale_freq_shift: float = 1,
+    ):
+        # Kept verbatim; forward() passes them on as keyword arguments.
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+
+    def forward(self, x: Tensor):
+        """
+        Forward method for Timesteps layer.
+
+        Parameters
+        ----------
+        x : Tensor
+            The input passed to `op.get_timestep_embedding`.
+
+        Returns
+        -------
+        ret : Tensor
+            The output of `op.get_timestep_embedding` with `embedding_dim`
+            set to `num_channels`.
+        """
+        embedding_kwargs = {
+            "embedding_dim": self.num_channels,
+            "flip_sin_to_cos": self.flip_sin_to_cos,
+            "downscale_freq_shift": self.downscale_freq_shift,
+        }
+        return op.get_timestep_embedding(x, **embedding_kwargs)
diff --git a/python/tvm/relax/frontend/nn/op.py 
b/python/tvm/relax/frontend/nn/op.py
index e5485d54c7..b3959cd95f 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=too-many-lines,invalid-name,protected-access
 """nn.Tensor operators."""
+import math
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 from tvm import tir as _tir
@@ -859,12 +860,10 @@ def full(
     result : Tensor
         The result tensor.
     """
-    from tvm import relax  # pylint: disable=import-outside-toplevel
-
     if isinstance(fill_value, (_tir.FloatImm, _tir.IntImm)):
-        fill_value = relax.const(fill_value.value, dtype=dtype)
+        fill_value = rx.const(fill_value.value, dtype=dtype)
     elif isinstance(fill_value, (int, float)):
-        fill_value = relax.const(fill_value, dtype=dtype)
+        fill_value = rx.const(fill_value, dtype=dtype)
     else:
         fill_value = fill_value._expr
     return _wrap_nested(_op.full(shape, fill_value, dtype), name)
@@ -896,6 +895,66 @@ def zeros(
     return _wrap_nested(_op.zeros(shape, dtype), name)
 
 
+def get_timestep_embedding(
+    x: Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+    name: str = "get_timestep_embedding",
+) -> Tensor:
+    """
+    Timestep calculation as described in Denoising Diffusion Probabilistic Models.
+
+    Parameters
+    ----------
+    x : Tensor
+        A 1-D Tensor of N indices.
+    embedding_dim : int
+        The dimension of the output.
+    flip_sin_to_cos : bool
+        If True, change the order of sine and cosine embeddings.
+    downscale_freq_shift : float
+        Adjusts the frequency of the sinusoidal sampling.
+    scale : float
+        Weight adjustment for embedding magnitude.
+    max_period : int
+        Controls the minimum frequency of the embeddings.
+    name : str
+        The name to label this operator with.
+
+    Returns
+    -------
+    result : Tensor
+        [N x dim] Tensor of positional embeddings.
+    """
+    timesteps = _op.astype(x._expr, "float32")
+
+    # Half of the channels carry sine values and the other half cosine values.
+    half_dim = embedding_dim // 2
+    exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
+        start=0, end=half_dim, dtype="float32"
+    )
+    exponent = exponent / (rx.const(half_dim - downscale_freq_shift, "float32"))
+
+    emb = _op.exp(exponent)
+    # Outer product: [N, 1] * [1, half_dim] -> [N, half_dim].
+    emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
+    # Scale embeddings
+    if scale != 1:
+        emb = rx.const(scale, "float32") * emb
+
+    # Concat sine and cosine embeddings.
+    if flip_sin_to_cos:
+        emb = _op.concat([_op.cos(emb), _op.sin(emb)], axis=-1)
+    else:
+        emb = _op.concat([_op.sin(emb), _op.cos(emb)], axis=-1)
+
+    # Zero pad the channel axis when embedding_dim is odd. relax.nn.pad takes
+    # widths as (before_1, after_1, ..., before_N, after_N) in axis order, so
+    # the single trailing column belongs in the last slot. This differs from
+    # torch.nn.functional.pad, which lists the last axis first.
+    if embedding_dim % 2 == 1:
+        emb = _op.nn.pad(emb, (0, 0, 0, 1))
+    return _wrap_nested(emb, name)
+
+
 def tensor_expr_op(
     tensor_expr_func: Callable,
     name_hint: str,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 1a4c3cceae..cb2718a654 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -20,7 +20,7 @@ from typing import List, Optional, Tuple, Union
 from tvm import DataType
 from tvm.tir import FloatImm
 
-from ...expr import Expr
+from ...expr import Expr, const
 from . import _ffi_api
 
 
@@ -413,6 +413,35 @@ def conv2d_transpose(
     )
 
 
+def pad(
+    data: Expr,
+    pad_width: Union[List[int], Tuple[int, ...]],
+    pad_value: Union[int, float, Expr] = 0,
+    pad_mode: str = "constant",
+) -> Expr:
+    r"""Padding
+
+    This operator takes in a tensor and pads each axis by the specified
+    widths using the specified value.
+
+    Parameters
+    ----------
+    data: relax.Expr
+        The input data to the operator
+    pad_width: flat sequence of int, required
+        Number of values padded to the edges of each axis, in the format
+        of (before_1, after_1, ..., before_N, after_N)
+    pad_value: Union[int, float, relax.Expr]
+        The value used for padding. A Python number is wrapped into a
+        relax constant before being passed to the operator.
+    pad_mode: 'constant', 'edge', 'reflect'
+        'constant' pads with pad_value
+        'edge' pads using the edge values of the input array
+        'reflect' pads by reflecting values with respect to the edge
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if not isinstance(pad_value, Expr):
+        pad_value = const(pad_value)
+    return _ffi_api.pad(data, pad_width, pad_value, pad_mode)
+
+
 def max_pool2d(
     data: Expr,
     pool_size: Union[int, Tuple[int, int]] = (1, 1),
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index 562b497cb2..6b2f25d0a2 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -184,6 +184,22 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> 
Expr:
     )
 
 
+@register_legalize("relax.nn.pad")
+def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
+    """Legalize ``relax.nn.pad`` into a call to ``topi.nn.pad``.
+
+    Only constant padding is lowered. Calls using any other ``pad_mode`` are
+    returned unchanged rather than being lowered to constant padding, which
+    would silently compute the wrong result.
+    """
+    if call.attrs.pad_mode != "constant":
+        return call
+
+    # Unpack pad_width into two separate lists for topi: the attribute stores
+    # (before_1, after_1, ..., before_N, after_N), so even slots are the
+    # leading widths and odd slots the trailing widths.
+    pad_widths = call.attrs.pad_width
+    pad_before = pad_widths[::2]
+    pad_after = pad_widths[1::2]
+    return bb.call_te(
+        topi.nn.pad,
+        call.args[0],
+        pad_before=pad_before,
+        pad_after=pad_after,
+        # The fill value is read out of the second argument's constant data.
+        pad_value=float(call.args[1].data.numpy()),
+        primfunc_name_hint="pad",
+    )
+
+
 @register_legalize("relax.nn.max_pool2d")
 def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
     if call.attrs.out_layout != call.attrs.layout:
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index cbd71da848..2e73297e2d 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -123,6 +123,52 @@ TVM_REGISTER_OP("relax.nn.log_softmax")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.pad */
+TVM_REGISTER_NODE_TYPE(PadAttrs);
+
+/*!
+ * \brief Build a call to relax.nn.pad.
+ * \param data The tensor to pad.
+ * \param pad_width Flattened per-axis widths (before_1, after_1, ..., before_N, after_N).
+ * \param pad_value The fill value, passed as the second call argument.
+ * \param pad_mode The padding mode stored in the call attributes.
+ * \return The constructed call expression.
+ */
+Expr pad(Expr data, Array<Integer> pad_width, Expr pad_value, String pad_mode) {
+  auto attrs = make_object<PadAttrs>();
+  attrs->pad_width = std::move(pad_width);
+  attrs->pad_mode = std::move(pad_mode);
+  static const Op& op = Op::Get("relax.nn.pad");
+  return Call(op, {data, pad_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad);
+
+/*!
+ * \brief Infer the output struct info of relax.nn.pad.
+ *
+ * Each output extent is the input extent plus the two pad widths of that axis.
+ * When the input shape is not available as a ShapeExpr, only ndim and dtype
+ * are propagated.
+ */
+StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<PadAttrs>();
+  int ndim = input_sinfo[0]->ndim;
+  Array<Integer> pad_width = attrs->pad_width;
+  ICHECK(static_cast<int>(pad_width.size()) == 2 * ndim) << "Illegal pad_width";
+
+  // `shape` may be undefined, or defined but not a ShapeExpr (for example a
+  // shape variable). `as<>` yields nullptr in both cases, so test the cast
+  // result instead of `defined()` before dereferencing it.
+  const auto* data_shape = input_sinfo[0]->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    // Shape isn't known, best we can do is return ndim and dtype.
+    return TensorStructInfo(input_sinfo[0]->dtype, ndim);
+  }
+
+  // Compute output shape by adding corresponding pad width to each axis.
+  Array<PrimExpr> out_shape;
+  for (int i = 0; i < ndim; i++) {
+    // Sum pad width for this axis.
+    PrimExpr added_width = pad_width[2 * i] + pad_width[(2 * i) + 1];
+    const PrimExpr current_width = data_shape->values[i];
+    out_shape.push_back(current_width + added_width);
+  }
+  return TensorStructInfo(ShapeExpr(out_shape), input_sinfo[0]->dtype);
+}
+
+TVM_REGISTER_OP("relax.nn.pad")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("pad_value", "Tensor", "The value to fill in padded area with.")
+    .set_attrs_type<PadAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.nn.batchnorm */
 bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
                             const Array<TensorStructInfo>& input_sinfo, 
Array<Integer> axes) {
   Op op = Downcast<Op>(call->op);
diff --git a/tests/python/relax/test_frontend_nn_modules.py 
b/tests/python/relax/test_frontend_nn_modules.py
index 77ba550194..68b03c5a21 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -26,6 +26,23 @@ from tvm.script import ir as I
 from tvm.script import relax as R
 
 
+def test_silu():
+    @R.function
+    def forward(
+        x: R.Tensor((3, 3), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+        with R.dataflow():
+            silu: R.Tensor((3, 3), dtype="float32") = R.nn.silu(x)
+            gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)) 
= silu, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.SiLU()
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3), 
"float32")}})
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_linear():
     @R.function
     def forward(
@@ -179,6 +196,91 @@ def test_embedding():
     assert_structural_equal(tvm_mod["forward"], forward, True)
 
 
+def test_timestep_embedding():
+    @R.function
+    def forward(
+        sample: R.Tensor((32, 32), dtype="float32"),
+        condition: R.Tensor((32, 16), dtype="float32"),
+        linear_1_weight: R.Tensor((32, 32), dtype="float32"),
+        linear_1_bias: R.Tensor((32,), dtype="float32"),
+        cond_proj_weight: R.Tensor((32, 16), dtype="float32"),
+        linear_2_weight: R.Tensor((32, 32), dtype="float32"),
+        linear_2_bias: R.Tensor((32,), dtype="float32"),
+        _io: R.Object,
+    ) -> R.Tuple(R.Tensor((32, 32), dtype="float32"), R.Tuple(R.Object)):
+        with R.dataflow():
+            permute_dims: R.Tensor((16, 32), dtype="float32") = R.permute_dims(
+                cond_proj_weight, axes=None
+            )
+            matmul: R.Tensor((32, 32), dtype="float32") = R.matmul(
+                condition, permute_dims, out_dtype="void"
+            )
+            add: R.Tensor((32, 32), dtype="float32") = R.add(sample, matmul)
+            permute_dims1: R.Tensor((32, 32), dtype="float32") = 
R.permute_dims(
+                linear_1_weight, axes=None
+            )
+            matmul1: R.Tensor((32, 32), dtype="float32") = R.matmul(
+                add, permute_dims1, out_dtype="void"
+            )
+            add1: R.Tensor((32, 32), dtype="float32") = R.add(matmul1, 
linear_1_bias)
+            silu: R.Tensor((32, 32), dtype="float32") = R.nn.silu(add1)
+            permute_dims2: R.Tensor((32, 32), dtype="float32") = 
R.permute_dims(
+                linear_2_weight, axes=None
+            )
+            matmul2: R.Tensor((32, 32), dtype="float32") = R.matmul(
+                silu, permute_dims2, out_dtype="void"
+            )
+            add2: R.Tensor((32, 32), dtype="float32") = R.add(matmul2, 
linear_2_bias)
+            gv1: R.Tuple(R.Tensor((32, 32), dtype="float32"), 
R.Tuple(R.Object)) = add2, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.TimestepEmbedding(32, 32, cond_proj_dim=16)
+    tvm_mod, _ = mod.export_tvm(
+        spec={
+            "forward": {
+                "sample": spec.Tensor((32, 32), "float32"),
+                "condition": spec.Tensor((32, 16), "float32"),
+            }
+        }
+    )
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
+def test_timesteps():
+    @R.function
+    def forward(
+        x: R.Tensor((3,), dtype="float32"), _io: R.Object
+    ) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
+        with R.dataflow():
+            lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
+            lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, 
axis=[1])
+            lv3: R.Tensor((5,), dtype="float32") = R.arange(
+                R.prim_value(0), R.prim_value(5), R.prim_value(1), 
dtype="float32"
+            )
+            lv4: R.Tensor((5,), dtype="float32") = R.multiply(
+                R.const(-9.2103404998779297, "float32"), lv3
+            )
+            lv5: R.Tensor((5,), dtype="float32") = R.divide(lv4, R.const(4, 
"float32"))
+            lv6: R.Tensor((5,), dtype="float32") = R.exp(lv5)
+            lv7: R.Tensor((1, 5), dtype="float32") = R.expand_dims(lv6, 
axis=[0])
+            lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
+            lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
+            lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
+            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = 
R.concat(
+                (lv9, lv10), axis=-1
+            )
+            gv1: R.Tuple(
+                R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)
+            ) = get_timestep_embedding, (_io,)
+            R.output(gv1)
+        return gv1
+
+    mod = modules.Timesteps(10)
+    tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3,), 
"float32")}})
+    assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
 def test_kv_cache():
     @I.ir_module
     class Module:
diff --git a/tests/python/relax/test_frontend_nn_op.py 
b/tests/python/relax/test_frontend_nn_op.py
index 048a671101..c404c18a68 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -259,6 +259,45 @@ def test_create():
     tvm.ir.assert_structural_equal(irmodule["test"], test)
 
 
+def test_timestep_embedding():
+    class Model(Module):
+        def test(self, x: Tensor):
+            get_timestep_out = op.get_timestep_embedding(x, 10)
+            return get_timestep_out
+
+    @R.function
+    def test(
+        x: R.Tensor((3,), dtype="float32"), _io: R.Object
+    ) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
+        with R.dataflow():
+            lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
+            lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1, 
axis=[1])
+            lv3: R.Tensor((5,), dtype="float32") = R.arange(
+                R.prim_value(0), R.prim_value(5), R.prim_value(1), 
dtype="float32"
+            )
+            lv4: R.Tensor((5,), dtype="float32") = R.multiply(
+                R.const(-9.2103404998779297, "float32"), lv3
+            )
+            lv5: R.Tensor((5,), dtype="float32") = R.divide(lv4, R.const(4, 
"float32"))
+            lv6: R.Tensor((5,), dtype="float32") = R.exp(lv5)
+            lv7: R.Tensor((1, 5), dtype="float32") = R.expand_dims(lv6, 
axis=[0])
+            lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
+            lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
+            lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
+            get_timestep_embedding: R.Tensor((3, 10), dtype="float32") = 
R.concat(
+                (lv9, lv10), axis=-1
+            )
+            gv1: R.Tuple(
+                R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)
+            ) = get_timestep_embedding, (_io,)
+            R.output(gv1)
+        return gv1
+
+    m = Model()
+    irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([3], 
"float32")}})
+    tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
 def test_tensor_expr_op():
     class Model(Module):
         def test(self, x: Tensor):
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index f69045421e..de1cf079a5 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -32,6 +32,7 @@ def test_op_correctness():
     assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax")
     assert relax.op.nn.log_softmax(x).op == Op.get("relax.nn.log_softmax")
     assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout")
+    assert relax.op.nn.pad(x, (1, 1, 1, 1)).op == Op.get("relax.nn.pad")
 
     x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
     gamma = relax.Var("gamma", R.Tensor((3,), "float32"))
@@ -1763,5 +1764,30 @@ def test_nll_loss_infer_struct_info_wrong_reduction():
         bb.normalize(relax.op.nn.nll_loss(x, y, w, reduction="foo"))
 
 
+def test_pad_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=2))
+
+    pad_width0 = (0, 0, 0, 0)
+    pad_width1 = (1, 1, 1, 1)
+    pad_width2 = (0, 1, 1, 0)
+
+    _check_inference(bb, relax.op.nn.pad(x, pad_width0), 
relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(
+        bb,
+        relax.op.nn.pad(x, pad_width1),
+        relax.TensorStructInfo((4, 5), dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.pad(x, pad_width2),
+        relax.TensorStructInfo((3, 4), dtype="float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.pad(x1, pad_width1), 
relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index b713712761..63e79c12cb 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3508,5 +3508,47 @@ def test_nll_loss_symbolic():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_pad():
+    @tvm.script.ir_module
+    class Pad:
+        @R.function
+        def main(x: R.Tensor((2, 128, 28), "float32")) -> R.Tensor((2, 130, 
30), "float32"):
+            gv: R.Tensor((2, 130, 30), "float32") = R.nn.pad(x, (0, 0, 1, 1, 
1, 1))
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 128, 28), dtype="float32")
+        ) -> R.Tensor((2, 130, 30), dtype="float32"):
+            gv = R.call_tir(Expected.pad, (x), out_sinfo=R.Tensor((2, 130, 
30), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def pad(
+            A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"),
+            PadInput: T.Buffer((T.int64(2), T.int64(130), T.int64(30)), 
"float32"),
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            # with T.block("root"):
+            for i0, i1, i2 in T.grid(T.int64(2), T.int64(130), T.int64(30)):
+                with T.block("PadInput"):
+                    v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+                    T.reads(A[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1)])
+                    T.writes(PadInput[v_i0, v_i1, v_i2])
+                    PadInput[v_i0, v_i1, v_i2] = T.if_then_else(
+                        T.int64(1) <= v_i1
+                        and v_i1 < T.int64(129)
+                        and T.int64(1) <= v_i2
+                        and v_i2 < T.int64(29),
+                        A[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1)],
+                        T.float32(0),
+                    )
+
+    mod = LegalizeOps()(Pad)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to