This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 7d1a101e24 [Unity][Frontend][NN] Add Timesteps layer to NN Module API (#15603)
7d1a101e24 is described below
commit 7d1a101e242bd1c0bb6ac2151528ef998f94c5ad
Author: Josh Fromm <[email protected]>
AuthorDate: Thu Aug 24 23:35:58 2023 -0700
[Unity][Frontend][NN] Add Timesteps layer to NN Module API (#15603)
* Add support for relax.nn.pad
* Formatting
* Add nn.pad testing
* Add support for Timesteps layer
* Simplify arithmetic
* Format
* Bump ci
* Format
* Add TimestepEmbedding layer
* Add TimestepEmbedding tests
* Fix cast style
* Add missing docstring
---
include/tvm/relax/attrs/nn.h | 20 +++-
python/tvm/relax/frontend/nn/modules.py | 97 ++++++++++++++++++++
python/tvm/relax/frontend/nn/op.py | 67 +++++++++++++-
python/tvm/relax/op/nn/nn.py | 31 ++++++-
python/tvm/relax/transform/legalize_ops/nn.py | 16 ++++
src/relax/op/nn/nn.cc | 46 ++++++++++
tests/python/relax/test_frontend_nn_modules.py | 102 +++++++++++++++++++++
tests/python/relax/test_frontend_nn_op.py | 39 ++++++++
tests/python/relax/test_op_nn.py | 26 ++++++
.../python/relax/test_transform_legalize_ops_nn.py | 42 +++++++++
10 files changed, 480 insertions(+), 6 deletions(-)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 2dc610f654..0d895dccb1 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -375,7 +375,7 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
}
}; // struct DropoutAttrs
-/*! \brief Attributes used in dropout operator */
+/*! \brief Attributes used in Attention operator */
struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
Optional<FloatImm> scale;
Optional<String> causal_mask;
@@ -388,6 +388,24 @@ struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
}
}; // struct AttentionAttrs
+/*! \brief Attributes used for the padding operator */
+struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
+  /*! \brief Flattened per-axis pad amounts: (before_1, after_1, ..., before_N, after_N). */
+  Array<Integer> pad_width;
+  /*! \brief Padding strategy; defaults to "constant". */
+  tvm::String pad_mode;
+
+  // The type key is namespaced under "relax", matching the enclosing tvm::relax namespace.
+  TVM_DECLARE_ATTRS(PadAttrs, "relax.attrs.PadAttrs") {
+    TVM_ATTR_FIELD(pad_width).describe(
+        "Number of values padded to the edges of each axis, "
+        "in the format of (before_1, after_1, ..., before_N, after_N)");
+    TVM_ATTR_FIELD(pad_mode)
+        .set_default("constant")
+        .describe(
+            "Padding type to use. \"constant\" pads with constant_value, "
+            "\"edge\" pads using the edge values of the input array, "
+            "\"reflect\" pads by reflecting values with respect to the edges.");
+  }
+};
+
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/frontend/nn/modules.py b/python/tvm/relax/frontend/nn/modules.py
index cd94c23115..fde18473ee 100644
--- a/python/tvm/relax/frontend/nn/modules.py
+++ b/python/tvm/relax/frontend/nn/modules.py
@@ -69,6 +69,15 @@ def _print(_, array: NDArray) -> None:
print(f"effect.print: shape = {array.shape}, dtype = {array.dtype}, data
=\n{array}")
+class SiLU(Module):
+    """
+    Activation layer that applies ``op.silu`` to its input.
+    """
+
+    def forward(self, x: Tensor):
+        """Return the result of ``op.silu`` applied to ``x``."""
+        activated = op.silu(x)
+        return activated
+
+
class Linear(Module):
"""
Module for linear layer.
@@ -524,3 +533,91 @@ class Embedding(Module):
),
shape=[*x.shape, self.dim], # TODO(@junrushao): revisit and
remove self.dim
)
+
+
+class TimestepEmbedding(Module):
+    """
+    Module for HF TimestepEmbedding layer.
+
+    Runs the input through ``linear_1``, a SiLU activation and ``linear_2``. A bias-free
+    projection of an optional conditioning tensor can be added to the input first, and an
+    optional SiLU post-activation can be applied to the output.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        time_embed_dim: int,
+        act_fn: str = "silu",
+        out_dim: Optional[int] = None,
+        post_act_fn: Optional[str] = None,
+        cond_proj_dim: Optional[int] = None,
+    ):
+        """
+        Parameters
+        ----------
+        in_channels : int
+            Input feature size of ``linear_1`` (and output size of ``cond_proj``).
+        time_embed_dim : int
+            Output feature size of ``linear_1`` and input feature size of ``linear_2``.
+        act_fn : str
+            Activation between the two linear layers. Only "silu" is accepted.
+        out_dim : Optional[int]
+            Output feature size of ``linear_2``. Falls back to ``time_embed_dim`` when None.
+        post_act_fn : Optional[str]
+            Activation applied after ``linear_2``. Only None or "silu" is accepted.
+        cond_proj_dim : Optional[int]
+            Feature size of the conditioning tensor. When None, no ``cond_proj`` is created.
+        """
+        self.linear_1 = Linear(in_channels, time_embed_dim)
+
+        if cond_proj_dim is not None:
+            self.cond_proj = Linear(cond_proj_dim, in_channels, bias=False)
+        else:
+            self.cond_proj = None
+
+        assert act_fn == "silu", "Only SiLU activations are supported."
+        self.act = SiLU()
+
+        time_embed_dim_out = out_dim if out_dim is not None else time_embed_dim
+        self.linear_2 = Linear(time_embed_dim, time_embed_dim_out)
+
+        if post_act_fn is None:
+            self.post_act = None
+        else:
+            # Validate the constructor argument: self.post_act has not been assigned yet,
+            # so reading it here would raise AttributeError.
+            assert post_act_fn == "silu", "Only SiLU post-activation supported."
+            self.post_act = SiLU()
+
+    def forward(self, sample: Tensor, condition: Optional[Tensor] = None):
+        """
+        Forward method for TimestepEmbedding layer.
+
+        Parameters
+        ----------
+        sample : Tensor
+            The input timestep that should be looked up.
+        condition : Optional[Tensor]
+            Optional additional projection matrix.
+
+        Returns
+        -------
+        ret : Tensor
+            The resulting embedding lookup for the input sample.
+        """
+        if condition is not None:
+            sample = sample + self.cond_proj(condition)
+        sample = self.linear_1(sample)
+
+        if self.act is not None:
+            sample = self.act(sample)
+
+        sample = self.linear_2(sample)
+
+        if self.post_act is not None:
+            sample = self.post_act(sample)
+        return sample
+
+
+class Timesteps(Module):
+    """
+    Module for HF timesteps layer.
+
+    Stores the embedding configuration and forwards it to
+    ``op.get_timestep_embedding`` on every call.
+    """
+
+    def __init__(
+        self,
+        num_channels: int,
+        flip_sin_to_cos: bool = False,
+        downscale_freq_shift: float = 1,
+    ):
+        # Kept verbatim; forward() passes them on as embedding_dim, flip_sin_to_cos and
+        # downscale_freq_shift respectively.
+        self.num_channels = num_channels
+        self.flip_sin_to_cos = flip_sin_to_cos
+        self.downscale_freq_shift = downscale_freq_shift
+
+    def forward(self, x: Tensor):
+        """Return ``op.get_timestep_embedding`` of ``x`` using the stored configuration."""
+        embedding_options = {
+            "embedding_dim": self.num_channels,
+            "flip_sin_to_cos": self.flip_sin_to_cos,
+            "downscale_freq_shift": self.downscale_freq_shift,
+        }
+        return op.get_timestep_embedding(x, **embedding_options)
diff --git a/python/tvm/relax/frontend/nn/op.py b/python/tvm/relax/frontend/nn/op.py
index e5485d54c7..b3959cd95f 100644
--- a/python/tvm/relax/frontend/nn/op.py
+++ b/python/tvm/relax/frontend/nn/op.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=too-many-lines,invalid-name,protected-access
"""nn.Tensor operators."""
+import math
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
from tvm import tir as _tir
@@ -859,12 +860,10 @@ def full(
result : Tensor
The result tensor.
"""
- from tvm import relax # pylint: disable=import-outside-toplevel
-
if isinstance(fill_value, (_tir.FloatImm, _tir.IntImm)):
- fill_value = relax.const(fill_value.value, dtype=dtype)
+ fill_value = rx.const(fill_value.value, dtype=dtype)
elif isinstance(fill_value, (int, float)):
- fill_value = relax.const(fill_value, dtype=dtype)
+ fill_value = rx.const(fill_value, dtype=dtype)
else:
fill_value = fill_value._expr
return _wrap_nested(_op.full(shape, fill_value, dtype), name)
@@ -896,6 +895,66 @@ def zeros(
return _wrap_nested(_op.zeros(shape, dtype), name)
+def get_timestep_embedding(
+    x: Tensor,
+    embedding_dim: int,
+    flip_sin_to_cos: bool = False,
+    downscale_freq_shift: float = 1,
+    scale: float = 1,
+    max_period: int = 10000,
+    name: str = "get_timestep_embedding",
+) -> Tensor:
+    """
+    Timestep calculation as described in Denoising Diffusion Probabilistic Models.
+
+    Parameters
+    ----------
+    x : Tensor
+        A 1-D Tensor of N indices.
+    embedding_dim : int
+        The dimension of the output.
+    flip_sin_to_cos : bool
+        If True, change the order of sine and cosine embeddings.
+    downscale_freq_shift : float
+        Adjusts the frequency of the sinusoidal sampling.
+    scale : float
+        Weight adjustment for embedding magnitude.
+    max_period : int
+        Controls the minimum frequency of the embeddings.
+    name : str
+        The name to label this operator with.
+
+    Returns
+    -------
+    result : Tensor
+        [N x dim] Tensor of positional embeddings. When ``embedding_dim`` is odd, the
+        last column is zero padding.
+    """
+    timesteps = _op.astype(x._expr, "float32")
+
+    # Sine and cosine each contribute half_dim columns.
+    half_dim = embedding_dim // 2
+    exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
+        start=0, end=half_dim, dtype="float32"
+    )
+    exponent = exponent / (rx.const(half_dim - downscale_freq_shift, "float32"))
+
+    emb = _op.exp(exponent)
+    # Outer product: [N, 1] * [1, half_dim] -> [N, half_dim].
+    emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
+    # Scale embeddings
+    if scale != 1:
+        emb = rx.const(scale, "float32") * emb
+
+    # Concat sine and cosine embeddings.
+    if flip_sin_to_cos:
+        emb = _op.concat([_op.cos(emb), _op.sin(emb)], axis=-1)
+    else:
+        emb = _op.concat([_op.sin(emb), _op.cos(emb)], axis=-1)
+
+    # Zero pad
+    if embedding_dim % 2 == 1:
+        # pad_width is ordered first axis first: (before_0, after_0, before_1, after_1).
+        # Append one zero column on the last axis so the result is [N, embedding_dim].
+        emb = _op.nn.pad(emb, (0, 0, 0, 1))
+    return _wrap_nested(emb, name)
+
+
def tensor_expr_op(
tensor_expr_func: Callable,
name_hint: str,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 1a4c3cceae..cb2718a654 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -20,7 +20,7 @@ from typing import List, Optional, Tuple, Union
from tvm import DataType
from tvm.tir import FloatImm
-from ...expr import Expr
+from ...expr import Expr, const
from . import _ffi_api
@@ -413,6 +413,35 @@ def conv2d_transpose(
)
+def pad(
+    data: Expr,
+    pad_width: Union[List[int], Tuple[int, ...]],
+    pad_value: Union[int, float, Expr] = 0,
+    pad_mode: str = "constant",
+) -> Expr:
+    r"""Padding
+
+    This operator takes in a tensor and pads each axis by the specified
+    widths using the specified value.
+
+    Parameters
+    ----------
+    data: relax.Expr
+        The input data to the operator
+    pad_width: flat sequence of int, required
+        Number of values padded to the edges of each axis, in the flattened
+        format (before_1, after_1, ..., before_N, after_N)
+    pad_value: int, float or relax.Expr
+        The value used for padding. A value that is not already a relax.Expr
+        is wrapped with ``const``.
+    pad_mode: 'constant', 'edge', 'reflect'
+        'constant' pads with the constant value pad_value
+        'edge' pads using the edge values of the input array
+        'reflect' pads by reflecting values with respect to the edge
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if not isinstance(pad_value, Expr):
+        pad_value = const(pad_value)
+    return _ffi_api.pad(data, pad_width, pad_value, pad_mode)
+
+
def max_pool2d(
data: Expr,
pool_size: Union[int, Tuple[int, int]] = (1, 1),
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 562b497cb2..6b2f25d0a2 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -184,6 +184,22 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.pad")
+def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
+    """Lower ``relax.nn.pad`` with constant padding to a ``topi.nn.pad`` TE call."""
+    # The topi.nn.pad call below is given only a fill value and no padding mode, so it
+    # can only express constant padding. For any other mode, return the call unchanged
+    # instead of silently emitting constant padding.
+    # NOTE(review): assumes returning the original call leaves the op un-legalized -- confirm.
+    if call.attrs.pad_mode != "constant":
+        return call
+
+    # Unpack pad_width into two separate lists for topi.
+    pad_widths = call.attrs.pad_width
+    pad_before = pad_widths[::2]
+    pad_after = pad_widths[1::2]
+    return bb.call_te(
+        topi.nn.pad,
+        call.args[0],
+        pad_before=pad_before,
+        pad_after=pad_after,
+        # The fill value is read from the second call argument's constant data.
+        pad_value=float(call.args[1].data.numpy()),
+        primfunc_name_hint="pad",
+    )
+
+
@register_legalize("relax.nn.max_pool2d")
def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.layout:
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index cbd71da848..2e73297e2d 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -123,6 +123,52 @@ TVM_REGISTER_OP("relax.nn.log_softmax")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.pad */
+TVM_REGISTER_NODE_TYPE(PadAttrs);
+
+Expr pad(Expr data, Array<Integer> pad_width, Expr pad_value, String pad_mode)
{
+ auto attrs = make_object<PadAttrs>();
+ attrs->pad_width = std::move(pad_width);
+ attrs->pad_mode = std::move(pad_mode);
+ static const Op& op = Op::Get("relax.nn.pad");
+ return Call(op, {data, pad_value}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad);
+
+/*!
+ * \brief Infer the output struct info of relax.nn.pad.
+ *
+ * Each output extent is the input extent plus that axis's before/after pad widths. When the
+ * input shape is not available as a literal ShapeExpr, only ndim and dtype are propagated.
+ */
+StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<PadAttrs>();
+  int ndim = input_sinfo[0]->ndim;
+  Array<Integer> pad_width = attrs->pad_width;
+  ICHECK(static_cast<int>(pad_width.size()) == 2 * ndim) << "Illegal pad_width";
+
+  // as<ShapeExprNode>() yields nullptr both when the shape is undefined and when it is
+  // defined but not a literal ShapeExpr; neither case can be indexed per axis.
+  const auto* data_shape = input_sinfo[0]->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    // Best we can do is return ndim and dtype.
+    return TensorStructInfo(input_sinfo[0]->dtype, ndim);
+  }
+
+  // Compute output shape by adding corresponding pad width to each axis.
+  Array<PrimExpr> out_shape;
+  for (int i = 0; i < ndim; i++) {
+    // Sum pad width for this axis.
+    PrimExpr added_width = pad_width[2 * i] + pad_width[(2 * i) + 1];
+    const PrimExpr current_width = data_shape->values[i];
+    out_shape.push_back(current_width + added_width);
+  }
+  return TensorStructInfo(ShapeExpr(out_shape), input_sinfo[0]->dtype);
+}
+
+TVM_REGISTER_OP("relax.nn.pad")
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("pad_value", "Tensor", "The value to fill in padded area
with.")
+ .set_attrs_type<PadAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.nn.batchnorm */
bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
const Array<TensorStructInfo>& input_sinfo,
Array<Integer> axes) {
Op op = Downcast<Op>(call->op);
diff --git a/tests/python/relax/test_frontend_nn_modules.py b/tests/python/relax/test_frontend_nn_modules.py
index 77ba550194..68b03c5a21 100644
--- a/tests/python/relax/test_frontend_nn_modules.py
+++ b/tests/python/relax/test_frontend_nn_modules.py
@@ -26,6 +26,23 @@ from tvm.script import ir as I
from tvm.script import relax as R
+def test_silu():
+ @R.function
+ def forward(
+ x: R.Tensor((3, 3), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ silu: R.Tensor((3, 3), dtype="float32") = R.nn.silu(x)
+ gv1: R.Tuple(R.Tensor((3, 3), dtype="float32"), R.Tuple(R.Object))
= silu, (_io,)
+ R.output(gv1)
+ return gv1
+
+ mod = modules.SiLU()
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3, 3),
"float32")}})
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
def test_linear():
@R.function
def forward(
@@ -179,6 +196,91 @@ def test_embedding():
assert_structural_equal(tvm_mod["forward"], forward, True)
+def test_timestep_embedding():
+ @R.function
+ def forward(
+ sample: R.Tensor((32, 32), dtype="float32"),
+ condition: R.Tensor((32, 16), dtype="float32"),
+ linear_1_weight: R.Tensor((32, 32), dtype="float32"),
+ linear_1_bias: R.Tensor((32,), dtype="float32"),
+ cond_proj_weight: R.Tensor((32, 16), dtype="float32"),
+ linear_2_weight: R.Tensor((32, 32), dtype="float32"),
+ linear_2_bias: R.Tensor((32,), dtype="float32"),
+ _io: R.Object,
+ ) -> R.Tuple(R.Tensor((32, 32), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ permute_dims: R.Tensor((16, 32), dtype="float32") = R.permute_dims(
+ cond_proj_weight, axes=None
+ )
+ matmul: R.Tensor((32, 32), dtype="float32") = R.matmul(
+ condition, permute_dims, out_dtype="void"
+ )
+ add: R.Tensor((32, 32), dtype="float32") = R.add(sample, matmul)
+ permute_dims1: R.Tensor((32, 32), dtype="float32") =
R.permute_dims(
+ linear_1_weight, axes=None
+ )
+ matmul1: R.Tensor((32, 32), dtype="float32") = R.matmul(
+ add, permute_dims1, out_dtype="void"
+ )
+ add1: R.Tensor((32, 32), dtype="float32") = R.add(matmul1,
linear_1_bias)
+ silu: R.Tensor((32, 32), dtype="float32") = R.nn.silu(add1)
+ permute_dims2: R.Tensor((32, 32), dtype="float32") =
R.permute_dims(
+ linear_2_weight, axes=None
+ )
+ matmul2: R.Tensor((32, 32), dtype="float32") = R.matmul(
+ silu, permute_dims2, out_dtype="void"
+ )
+ add2: R.Tensor((32, 32), dtype="float32") = R.add(matmul2,
linear_2_bias)
+ gv1: R.Tuple(R.Tensor((32, 32), dtype="float32"),
R.Tuple(R.Object)) = add2, (_io,)
+ R.output(gv1)
+ return gv1
+
+ mod = modules.TimestepEmbedding(32, 32, cond_proj_dim=16)
+ tvm_mod, _ = mod.export_tvm(
+ spec={
+ "forward": {
+ "sample": spec.Tensor((32, 32), "float32"),
+ "condition": spec.Tensor((32, 16), "float32"),
+ }
+ }
+ )
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
+def test_timesteps():
+ @R.function
+ def forward(
+ x: R.Tensor((3,), dtype="float32"), _io: R.Object
+ ) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
+ lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1,
axis=[1])
+ lv3: R.Tensor((5,), dtype="float32") = R.arange(
+ R.prim_value(0), R.prim_value(5), R.prim_value(1),
dtype="float32"
+ )
+ lv4: R.Tensor((5,), dtype="float32") = R.multiply(
+ R.const(-9.2103404998779297, "float32"), lv3
+ )
+ lv5: R.Tensor((5,), dtype="float32") = R.divide(lv4, R.const(4,
"float32"))
+ lv6: R.Tensor((5,), dtype="float32") = R.exp(lv5)
+ lv7: R.Tensor((1, 5), dtype="float32") = R.expand_dims(lv6,
axis=[0])
+ lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
+ lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
+ lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
+ get_timestep_embedding: R.Tensor((3, 10), dtype="float32") =
R.concat(
+ (lv9, lv10), axis=-1
+ )
+ gv1: R.Tuple(
+ R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)
+ ) = get_timestep_embedding, (_io,)
+ R.output(gv1)
+ return gv1
+
+ mod = modules.Timesteps(10)
+ tvm_mod, _ = mod.export_tvm(spec={"forward": {"x": spec.Tensor((3,),
"float32")}})
+ assert_structural_equal(tvm_mod["forward"], forward, True)
+
+
def test_kv_cache():
@I.ir_module
class Module:
diff --git a/tests/python/relax/test_frontend_nn_op.py b/tests/python/relax/test_frontend_nn_op.py
index 048a671101..c404c18a68 100644
--- a/tests/python/relax/test_frontend_nn_op.py
+++ b/tests/python/relax/test_frontend_nn_op.py
@@ -259,6 +259,45 @@ def test_create():
tvm.ir.assert_structural_equal(irmodule["test"], test)
+def test_timestep_embedding():
+ class Model(Module):
+ def test(self, x: Tensor):
+ get_timestep_out = op.get_timestep_embedding(x, 10)
+ return get_timestep_out
+
+ @R.function
+ def test(
+ x: R.Tensor((3,), dtype="float32"), _io: R.Object
+ ) -> R.Tuple(R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)):
+ with R.dataflow():
+ lv1: R.Tensor((3,), dtype="float32") = R.astype(x, dtype="float32")
+ lv2: R.Tensor((3, 1), dtype="float32") = R.expand_dims(lv1,
axis=[1])
+ lv3: R.Tensor((5,), dtype="float32") = R.arange(
+ R.prim_value(0), R.prim_value(5), R.prim_value(1),
dtype="float32"
+ )
+ lv4: R.Tensor((5,), dtype="float32") = R.multiply(
+ R.const(-9.2103404998779297, "float32"), lv3
+ )
+ lv5: R.Tensor((5,), dtype="float32") = R.divide(lv4, R.const(4,
"float32"))
+ lv6: R.Tensor((5,), dtype="float32") = R.exp(lv5)
+ lv7: R.Tensor((1, 5), dtype="float32") = R.expand_dims(lv6,
axis=[0])
+ lv8: R.Tensor((3, 5), dtype="float32") = R.multiply(lv2, lv7)
+ lv9: R.Tensor((3, 5), dtype="float32") = R.sin(lv8)
+ lv10: R.Tensor((3, 5), dtype="float32") = R.cos(lv8)
+ get_timestep_embedding: R.Tensor((3, 10), dtype="float32") =
R.concat(
+ (lv9, lv10), axis=-1
+ )
+ gv1: R.Tuple(
+ R.Tensor((3, 10), dtype="float32"), R.Tuple(R.Object)
+ ) = get_timestep_embedding, (_io,)
+ R.output(gv1)
+ return gv1
+
+ m = Model()
+ irmodule, _ = m.export_tvm(spec={"test": {"x": spec.Tensor([3],
"float32")}})
+ tvm.ir.assert_structural_equal(irmodule["test"], test)
+
+
def test_tensor_expr_op():
class Model(Module):
def test(self, x: Tensor):
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index f69045421e..de1cf079a5 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -32,6 +32,7 @@ def test_op_correctness():
assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax")
assert relax.op.nn.log_softmax(x).op == Op.get("relax.nn.log_softmax")
assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout")
+ assert relax.op.nn.pad(x, (1, 1, 1, 1)).op == Op.get("relax.nn.pad")
x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
gamma = relax.Var("gamma", R.Tensor((3,), "float32"))
@@ -1763,5 +1764,30 @@ def test_nll_loss_infer_struct_info_wrong_reduction():
bb.normalize(relax.op.nn.nll_loss(x, y, w, reduction="foo"))
+def test_pad_infer_struct_info():
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", R.Tensor((2, 3), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=2))
+
+ pad_width0 = (0, 0, 0, 0)
+ pad_width1 = (1, 1, 1, 1)
+ pad_width2 = (0, 1, 1, 0)
+
+ _check_inference(bb, relax.op.nn.pad(x, pad_width0),
relax.TensorStructInfo((2, 3), "float32"))
+ _check_inference(
+ bb,
+ relax.op.nn.pad(x, pad_width1),
+ relax.TensorStructInfo((4, 5), dtype="float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.pad(x, pad_width2),
+ relax.TensorStructInfo((3, 4), dtype="float32"),
+ )
+ _check_inference(
+ bb, relax.op.nn.pad(x1, pad_width1),
relax.TensorStructInfo(dtype="float32", ndim=2)
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index b713712761..63e79c12cb 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3508,5 +3508,47 @@ def test_nll_loss_symbolic():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_pad():
+ @tvm.script.ir_module
+ class Pad:
+ @R.function
+ def main(x: R.Tensor((2, 128, 28), "float32")) -> R.Tensor((2, 130,
30), "float32"):
+ gv: R.Tensor((2, 130, 30), "float32") = R.nn.pad(x, (0, 0, 1, 1,
1, 1))
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor((2, 128, 28), dtype="float32")
+ ) -> R.Tensor((2, 130, 30), dtype="float32"):
+ gv = R.call_tir(Expected.pad, (x), out_sinfo=R.Tensor((2, 130,
30), dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def pad(
+ A: T.Buffer((T.int64(2), T.int64(128), T.int64(28)), "float32"),
+ PadInput: T.Buffer((T.int64(2), T.int64(130), T.int64(30)),
"float32"),
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i0, i1, i2 in T.grid(T.int64(2), T.int64(130), T.int64(30)):
+ with T.block("PadInput"):
+ v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
+ T.reads(A[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1)])
+ T.writes(PadInput[v_i0, v_i1, v_i2])
+ PadInput[v_i0, v_i1, v_i2] = T.if_then_else(
+ T.int64(1) <= v_i1
+ and v_i1 < T.int64(129)
+ and T.int64(1) <= v_i2
+ and v_i2 < T.int64(29),
+ A[v_i0, v_i1 - T.int64(1), v_i2 - T.int64(1)],
+ T.float32(0),
+ )
+
+ mod = LegalizeOps()(Pad)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()