This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 7dd27f075832f9ea9d89d1116f637de8480bd5c9
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Tue Feb 14 15:03:22 2023 -0500

    [Unity] Relax op: neural networks (#13993)
    
    This PR is part of bringing the high-level tensor computation operators
    into Relax.
    
    It adds the neural network operators: conv2d, max_pool2d,
    adaptive_avg_pool2d, relu, gelu, silu, softmax, batch_norm, layer_norm,
    and dropout.
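
    A minimal usage sketch (not part of this commit; the shapes, dtypes, and
    variable names are illustrative assumptions) showing how the new operators
    can be constructed through the Python API added in
    python/tvm/relax/op/nn/nn.py:

        from tvm import relax
        from tvm.relax.op import nn

        # Illustrative struct info; any valid shapes/dtypes work the same way.
        data = relax.Var("data", relax.TensorStructInfo((1, 3, 32, 32), "float32"))
        weight = relax.Var("weight", relax.TensorStructInfo((16, 3, 3, 3), "float32"))

        conv = nn.conv2d(data, weight, strides=(1, 1), padding=(1, 1))
        pooled = nn.max_pool2d(conv, pool_size=(2, 2), strides=(2, 2))
        probs = nn.softmax(nn.relu(pooled), axis=1)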
---
 include/tvm/relax/attrs/nn.h                       | 190 +++++
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/{ => nn}/__init__.py           |  31 +-
 .../tvm/relax/op/{__init__.py => nn/_ffi_api.py}   |  30 +-
 python/tvm/relax/op/nn/nn.py                       | 524 ++++++++++++
 python/tvm/relax/op/op_attrs.py                    |  35 +
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/op/nn/convolution.cc                     | 146 ++++
 src/relax/op/nn/convolution.h                      |  63 ++
 src/relax/op/nn/nn.cc                              | 245 ++++++
 src/relax/op/nn/nn.h                               |  81 ++
 src/relax/op/nn/pooling.cc                         | 184 ++++
 src/relax/op/nn/pooling.h                          |  46 +
 tests/python/relax/test_op_nn.py                   | 929 +++++++++++++++++++++
 tests/python/relax/test_op_nn_convolution.py       | 429 ++++++++++
 tests/python/relax/test_op_nn_pooling.py           | 429 ++++++++++
 tests/python/relax/test_tvmscript_parser_op_nn.py  | 193 +++++
 17 files changed, 3503 insertions(+), 55 deletions(-)
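
The added tests also cover parsing these operators from TVMScript (see
tests/python/relax/test_tvmscript_parser_op_nn.py). A hedged sketch of what
such a module can look like, with hypothetical function name and tensor shapes:

    from tvm.script import ir as I
    from tvm.script import relax as R

    @I.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, 3, 28, 28), "float32")) -> R.Tensor((2, 3, 28, 28), "float32"):
            # R.nn.relu is exposed through the ir_builder change in
            # python/tvm/script/ir_builder/relax/ir.py
            gv: R.Tensor((2, 3, 28, 28), "float32") = R.nn.relu(x)
            return gv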

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
new file mode 100644
index 0000000000..694a510706
--- /dev/null
+++ b/include/tvm/relax/attrs/nn.h
@@ -0,0 +1,190 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/nn.h
+ * \brief Attributes for neural network operators.
+ */
+#ifndef TVM_RELAX_ATTRS_NN_H_
+#define TVM_RELAX_ATTRS_NN_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes used in Conv2d operator */
+struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
+  Array<IntImm> strides;
+  Array<IntImm> padding;
+  Array<IntImm> dilation;
+  int groups;
+  String data_layout;
+  String kernel_layout;
+  String out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv2DAttrs, "relax.attrs.Conv2DAttrs") {
+    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the convolution.");
+    TVM_ATTR_FIELD(padding).describe(
+        "If padding is non-zero, then the input is implicitly zero-padded"
+        "Padding support both symmetric and asymmetric as"
+        "one int : same padding used on all sides"
+        "two int : bottom, right will use same padding as top, left"
+        "four int : padding width in the order of (top, left, bottom, right)");
+    TVM_ATTR_FIELD(dilation).describe(
+        "Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).describe(
+        "Number of groups to split the input into for grouped convolution. The 
number of input and "
+        "output channels should be divisible by the number of groups.");
+    TVM_ATTR_FIELD(data_layout)
+        .describe(
+            "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Convolution is applied on the 'H' and"
+            "'W' dimensions.");
+    TVM_ATTR_FIELD(kernel_layout)
+        .describe(
+            "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
+            "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, 
and width"
+            "dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout)
+        .describe(
+            "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Default to be same as input layout.");
+    TVM_ATTR_FIELD(out_dtype).describe(
+        "Output data type, set to explicit type under mixed precision 
setting");
+  }
+};  // struct Conv2dAttrs
+
+/*! \brief Attributes used in max_pool2d operator */
+struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
+  Array<IntImm> pool_size;
+  Array<IntImm> strides;
+  Array<IntImm> padding;
+  Array<IntImm> dilation;
+  bool ceil_mode;
+  String layout;
+  String out_layout;
+
+  TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relax.attrs.MaxPool2DAttrs") {
+    TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the pooling.");
+    TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the pooling.");
+    TVM_ATTR_FIELD(padding).describe(
+        "If padding is non-zero, then the input is implicitly zero-padded"
+        "Padding support both symmetric and asymmetric as"
+        "one int : same padding used on all sides"
+        "two int : bottom, right will use same padding as top, left"
+        "four int : padding width in the order of (top, left, bottom, right)");
+    TVM_ATTR_FIELD(ceil_mode).describe(
+        "A boolean indicating if use ceil or floor to compute the output 
shape. By using ceil, "
+        "every element in the input tensor will be covered by a sliding 
window.");
+    TVM_ATTR_FIELD(layout).describe(
+        "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+        "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+        "dimensions respectively. Pooling is applied on the 'H' and"
+        "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .describe(
+            "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Pooling is applied on the 'H' and"
+            "'W' dimensions.");
+  }
+};  // struct MaxPool2dAttrs
+
+/*! \brief Attributes for 2d adaptive pool operator */
+struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
+  Optional<Array<IntImm>> output_size;
+  String layout;
+  String out_layout;
+
+  TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relax.attrs.AdaptivePool2DAttrs") {
+    TVM_ATTR_FIELD(output_size).describe("Output height and width.");
+    TVM_ATTR_FIELD(layout).describe(
+        "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+        "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+        "dimensions respectively. Pooling is applied on the 'H' and"
+        "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .describe(
+            "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Pooling is applied on the 'H' and"
+            "'W' dimensions.");
+  }
+};  // struct AdaptivePool2DAttrs
+
+/*! \brief Attributes used in softmax operators */
+struct SoftmaxAttrs : public tvm::AttrsNode<SoftmaxAttrs> {
+  int axis;
+
+  TVM_DECLARE_ATTRS(SoftmaxAttrs, "relax.attrs.SoftmaxAttrs") {
+    TVM_ATTR_FIELD(axis).describe("The axis to sum over when computing softmax.");
+  }
+};
+
+/*! \brief Attributes used in batch_norm operator */
+struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
+  int axis;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(BatchNormAttrs, "relax.attrs.BatchNormAttrs") {
+    TVM_ATTR_FIELD(axis).describe("The axis along which the normalization is applied.");
+    TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero");
+    TVM_ATTR_FIELD(center).describe(
+        "Indicating if the beta offset will be added to the normalized tensor.");
+    TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
+  }
+};  // struct BatchNormAttrs
+
+/*! \brief Attributes used in layer_norm operator */
+struct LayerNormAttrs : public tvm::AttrsNode<LayerNormAttrs> {
+  Array<Integer> axes;
+  double epsilon;
+  bool center;
+  bool scale;
+
+  TVM_DECLARE_ATTRS(LayerNormAttrs, "relax.attrs.LayerNormAttrs") {
+    TVM_ATTR_FIELD(axes).describe("The axes along which the normalization is applied.");
+    TVM_ATTR_FIELD(epsilon).describe("Small float added to variance to avoid dividing by zero");
+    TVM_ATTR_FIELD(center).describe(
+        "Indicating if the beta offset will be added to the normalized tensor.");
+    TVM_ATTR_FIELD(scale).describe("Indicating if the gamma scale will be multiplied.");
+  }
+};  // struct LayerNormAttrs
+
+/*! \brief Attributes used in dropout operator */
+struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
+  double rate;
+
+  TVM_DECLARE_ATTRS(DropoutAttrs, "relax.attrs.DropoutAttrs") {
+    TVM_ATTR_FIELD(rate).describe(
+        "Fraction of the input that gets dropped out during training time");
+  }
+};  // struct DropoutAttrs
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_NN_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 68152c2056..6c6fffc7c6 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -31,6 +31,7 @@ from .unary import *
 from . import builtin
 from . import image
 from . import memory
+from . import nn
 
 
 def _register_op_make():
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/nn/__init__.py
similarity index 57%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/nn/__init__.py
index 68152c2056..af2aa106bc 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -14,31 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
-
-# Operators
-from .base import *
-from .binary import *
-from .datatype import *
-from .index import *
-from .manipulate import *
-from .op_attrs import *
-from .statistical import *
-from .set import *
-from .ternary import *
-from .unary import *
-from . import builtin
-from . import image
-from . import memory
-
-
-def _register_op_make():
-    # pylint: disable=import-outside-toplevel
-    from . import _ffi_api
-    from .. import expr
-
-    expr._op_ffi_api = _ffi_api  # type: ignore
-
-
-_register_op_make()
+# pylint: disable=wildcard-import
+"""Neural network related operators."""
+from .nn import *
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/nn/_ffi_api.py
similarity index 57%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/nn/_ffi_api.py
index 68152c2056..1785345ac1 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/nn/_ffi_api.py
@@ -14,31 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""Constructor APIs"""
+import tvm._ffi
 
-# Operators
-from .base import *
-from .binary import *
-from .datatype import *
-from .index import *
-from .manipulate import *
-from .op_attrs import *
-from .statistical import *
-from .set import *
-from .ternary import *
-from .unary import *
-from . import builtin
-from . import image
-from . import memory
-
-
-def _register_op_make():
-    # pylint: disable=import-outside-toplevel
-    from . import _ffi_api
-    from .. import expr
-
-    expr._op_ffi_api = _ffi_api  # type: ignore
-
-
-_register_op_make()
+tvm._ffi._init_api("relax.op.nn", __name__)
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
new file mode 100644
index 0000000000..cdf0e96464
--- /dev/null
+++ b/python/tvm/relax/op/nn/nn.py
@@ -0,0 +1,524 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Relax Neural Network (NN) operators"""
+from typing import List, Optional, Tuple, Union
+
+from tvm import DataType
+
+from . import _ffi_api
+from ...expr import Expr
+
+
+def conv2d(
+    data: Expr,
+    weight: Expr,
+    strides: Union[int, Tuple[int, int]] = (1, 1),
+    padding: Union[int, Tuple[int, ...]] = (0, 0),
+    dilation: Union[int, Tuple[int, int]] = (1, 1),
+    groups: int = 1,
+    data_layout: str = "NCHW",
+    kernel_layout: str = "OIHW",
+    out_layout: Optional[str] = None,
+    out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+    r"""2D convolution.
+
+    This operator takes the weight as the convolution kernel
+    and convolves it with data to produce an output.
+
+
+    In the default case, where the data_layout is `NCHW`
+    and kernel_layout is `OIHW`, conv2d takes in
+    a data Tensor with shape `(batch_size, in_channels, height, width)`,
+    and a weight Tensor with shape `(channels, in_channels, kernel_h, kernel_w)`,
+    where `kernel_h` and `kernel_w` are the lengths of the `H` and `W` kernel dimensions,
+    to produce an output Tensor with the following rule:
+
+    .. math::
+
+        \mbox{out}[b, c, y, x] = \sum_{dy, dx, k}
+           \mbox{data}[b, k, \mbox{strides}[0] * y  + dy, \mbox{strides}[1] * x + dx] *
+           \mbox{weight}[c, k, dy, dx]
+
+    Padding and dilation are applied to data and weight respectively before the computation.
+    This operator accepts data layout specification.
+    Semantically, the operator will convert the layout to the canonical layout
+    (`NCHW` for data and `OIHW` for weight), perform the computation,
+    then convert to the out_layout.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    weight : relax.Expr
+        The weight expressions.
+
+    strides : Union[int, Tuple[int, int]]
+        The strides of convolution. It is required to have length either 1 or 2.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding of convolution on both sides of inputs before convolution.
+        It is required to have length either 1, 2 or 4.
+
+    dilation : Union[int, Tuple[int, int]]
+        Specifies the dilation rate to be used for dilated convolution.
+        It is required to have length either 1 or 2.
+
+    groups : int
+        Number of groups to split the input into for grouped convolution.
+        The number of input and output channels should be divisible by the number of groups.
+
+    data_layout : str
+        Layout of the input.
+
+    kernel_layout : str
+        Layout of the weight.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    out_dtype : Optional[Union[str, DataType]]
+        Specifies the output data type for mixed precision conv2d.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(strides, int):
+        strides = (strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding, padding, padding)
+
+    return _ffi_api.conv2d(  # type: ignore
+        data,
+        weight,
+        strides,
+        padding,
+        dilation,
+        groups,
+        data_layout,
+        kernel_layout,
+        out_layout,
+        out_dtype,
+    )
+
+
+def max_pool2d(
+    data: Expr,
+    pool_size: Union[int, Tuple[int, int]] = (1, 1),
+    strides: Union[int, Tuple[int, int]] = (1, 1),
+    padding: Union[int, Tuple[int, ...]] = (0, 0),
+    dilation: Union[int, Tuple[int, int]] = (1, 1),
+    ceil_mode: bool = False,
+    layout: str = "NCHW",
+    out_layout: Optional[str] = None,
+) -> Expr:
+    r"""2D maximum pooling operator.
+
+    This operator takes data as input and does 2D max value calculation
+    within a pool_size-sized window, with strides defined by strides.
+
+
+    In the default case, where the data_layout is `NCHW`,
+    it takes a data Tensor with shape `(batch_size, in_channels, height, width)`
+    and produces an output Tensor with the following rule:
+
+    with data of shape (b, c, h, w) and pool_size (kh, kw)
+
+    .. math::
+
+        \mbox{out}(b, c, y, x)  = \max_{m=0, \ldots, kh-1} \max_{n=0, \ldots, kw-1}
+             \mbox{data}(b, c, \mbox{stride}[0] * y + m, \mbox{stride}[1] * x + n)
+
+    Padding is applied to data before the computation.
+    ceil_mode is used to take ceil or floor while computing out shape.
+    This operator accepts data layout specification.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    pool_size : Union[int, Tuple[int, int]]
+        The size of window for pooling. It is required to have length either 1 or 2.
+
+    strides : Union[int, Tuple[int, int]]
+        The strides of pooling. It is required to have length either 1 or 2.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding for pooling. It is required to have length either 1, 2 or 4.
+
+    dilation : Union[int, Tuple[int, int]]
+        The dilation of pooling. It is required to have length either 1 or 2.
+
+    ceil_mode : bool
+        A boolean indicating if use ceil or floor to compute the output shape.
+        By using ceil, every element in the input tensor will be covered by a sliding window.
+
+    layout : str
+        Layout of the input.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    Returns
+    -------
+    result : Expr
+        The computed result.
+    """
+    if isinstance(pool_size, int):
+        pool_size = (pool_size, pool_size)
+    if isinstance(strides, int):
+        strides = (strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding, padding, padding)
+
+    return _ffi_api.max_pool2d(  # type: ignore
+        data, pool_size, strides, padding, dilation, ceil_mode, layout, out_layout
+    )
+
+
+def adaptive_avg_pool2d(
+    data: Expr,
+    output_size: Optional[Union[int, Tuple[int, int]]] = None,
+    layout: str = "NCHW",
+    out_layout: Optional[str] = None,
+) -> Expr:
+    r"""2D adaptive average pooling operator. This operator is experimental.
+
+    This operator takes data as input and does 2D average value calculation
+    across each window represented by WxH.
+
+
+    In the default case, where the data_layout is `NCHW`,
+    it takes a data Tensor with shape `(batch_size, in_channels, height, width)`
+    and produces an output Tensor with shape
+    (batch_size, in_channels, output_height, output_width).
+
+    The pooling kernel and stride sizes are automatically chosen for
+    desired output sizes.
+
+    For output_size:
+        If this argument is not provided, input height and width will be used
+        as output height and width.
+
+        If a single integer is provided for output_size, the output size is
+        (N x C x output_size x output_size) for any input (NCHW).
+
+        If a tuple of integers (height, width) are provided for output_size,
+        the output size is (N x C x height x width) for any input (NCHW).
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    output_size : Optional[Union[int, Tuple[int, int]]]
+        Output height and width.
+        If not specified, it will be the same as the input height and width.
+        If specified, it is required to have length either 1 or 2.
+
+    layout : str
+        Layout of the input.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(output_size, int):
+        output_size = (output_size, output_size)
+    return _ffi_api.adaptive_avg_pool2d(data, output_size, layout, out_layout)  # type: ignore
+
+
+def relu(data: Expr) -> Expr:
+    """Rectified linear unit.
+
+    .. math::
+        \text{ReLU}(x) = \max(x, 0)
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.relu(data)  # type: ignore
+
+
+def gelu(data: Expr) -> Expr:
+    """Gaussian Error Linear Units function
+
+    .. math::
+        \text{GeLU}(x) = 0.5 * x * (1 + \text{erf}(x * 0.5**0.5))
+
+    where :math:`erf` is the Gauss Error function.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+
+    Note
+    ----
+    The input tensor is required to have float dtype
+    """
+    return _ffi_api.gelu(data)  # type: ignore
+
+
+def silu(data: Expr) -> Expr:
+    """Sigmoid Linear Unit function
+
+    .. math::
+        \text{SiLU}(x) = x * \text{sigmoid}(x)
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+
+    Note
+    ----
+    The input tensor is required to have float dtype
+    """
+    return _ffi_api.silu(data)  # type: ignore
+
+
+def softmax(data: Expr, axis: int = -1) -> Expr:
+    r"""Computes softmax.
+
+    .. math:: \text{softmax}(x)_i = \frac{\exp(x_i)}{\sum_j \exp(x_j)}
+
+    Parameters
+    ----------
+    data: relax.Expr
+        The input data to the operator.
+
+    axis: int
+        The axis to sum over when computing softmax.
+        If not specified, it is by default the last axis of the input tensor.
+        Supports negative indexing.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+
+    Note
+    ----
+    The input tensor is required to have float dtype
+    """
+    return _ffi_api.softmax(data, axis)  # type: ignore
+
+
+def batch_norm(
+    data: Expr,
+    gamma: Expr,
+    beta: Expr,
+    moving_mean: Expr,
+    moving_var: Expr,
+    axis: int,
+    epsilon: float = 1e-5,
+    center: bool = True,
+    scale: bool = True,
+) -> Expr:
+    r"""
+    Batch normalization layer (Ioffe and Szegedy, 2014).
+    Normalizes the input at each batch, i.e. applies a transformation
+    that maintains the mean activation close to 0 and the activation
+    standard deviation close to 1.
+
+    .. math::
+
+        data\_mean[i] = mean(data[:,i,:,...]) \\
+        data\_var[i] = var(data[:,i,:,...])
+
+    Then compute the normalized output, which has the same shape as input, as follows:
+
+    .. math::
+
+        out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}}
+            * gamma[i] + beta[i]
+
+    Both *mean* and *var* return a scalar by treating the input as a vector.
+
+    Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+    have shape *(k,)*.
+
+    Besides the inputs and the outputs, this operator accepts two auxiliary
+    states, ``moving_mean`` and ``moving_var``, which are *k*-length
+    vectors. They are global statistics for the whole dataset, which are updated by
+
+    .. code:: python
+
+        moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
+        moving_var = moving_var * momentum + data_var * (1 - momentum)
+
+    The parameter ``axis`` specifies which axis of the input shape denotes
+    the 'channel' (separately normalized groups).  The default is 1.
+    Specifying -1 sets the channel axis to be the last item in the input shape.
+
+    .. note::
+
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    gamma : relax.Expr
+        The gamma scale factor.
+
+    beta : relax.Expr
+        The beta offset factor.
+
+    moving_mean : relax.Expr
+        Running mean of input.
+
+    moving_var : relax.Expr
+        Running variance of input.
+
+    axis : int
+        The axis along which the normalization is applied.
+
+    epsilon : float
+        Small float added to variance to avoid dividing by zero.
+
+    center : bool
+        Indicating if the beta offset will be added to the normalized tensor.
+
+    scale : bool
+        Indicating if the gamma scale will be multiplied.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.batch_norm(  # type: ignore
+        data, gamma, beta, moving_mean, moving_var, axis, epsilon, center, scale
+    )
+
+
+def layer_norm(
+    data: Expr,
+    gamma: Expr,
+    beta: Expr,
+    axes: Union[int, List[int]],
+    epsilon: float = 1e-5,
+    center: bool = True,
+    scale: bool = True,
+) -> Expr:
+    r"""
+    Layer normalization (Lei Ba and et al., 2016).
+    Applies layer normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array and normalizes
+    the input using the given axis:
+
+    .. math::
+
+        out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis)+\epsilon}}
+            * gamma + beta
+
+    Unlike batch normalization, the mean and var are computed along the channel dimension.
+
+    Assume the input has size k on axis 1, then both gamma and beta have shape (k,).
+
+    .. note::
+
+        This operator can be optimized away for inference.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        Input to which layer_norm will be applied.
+
+    gamma : relax.Expr
+        The gamma scale factor.
+
+    beta : relax.Expr
+        The beta offset factor.
+
+    axes : Union[int, List[int]]
+        The axes along which the normalization is applied.
+
+    epsilon : float
+        Small float added to variance to avoid dividing by zero.
+
+    center : bool
+        Indicating if the beta offset will be added to the normalized tensor.
+
+    scale : bool
+        Indicating if the gamma scale will be multiplied.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    if isinstance(axes, int):
+        axes = [axes]
+    return _ffi_api.layer_norm(data, gamma, beta, axes, epsilon, center, scale)  # type: ignore
+
+
+def dropout(data: Expr, rate: float = 0.5) -> Expr:
+    """Applies the dropout operation to the input tensor.
+
+    During training, each element of the input is set to zero with
+    probability ``p``. The whole array is scaled by ``1/(1-p)``
+    to keep the expected sum of the input unchanged.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    rate : float
+        The probability for an element to be reset to 0.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result of dropout, which is a tuple of two tensors.
+        The first one is the original tensor and the second one is a
+        mask tensor (1.0 where element not dropped, 0.0 where dropped)
+    """
+    return _ffi_api.dropout(data, rate)  # type: ignore
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 1fb8853040..68f84b3514 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -34,6 +34,41 @@ class StridedSliceAttrs(Attrs):
     """Attributes used in strided_slice operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.Conv2DAttrs")
+class Conv2DAttrs(Attrs):
+    """Attributes for nn.conv2d"""
+
+
+@tvm._ffi.register_object("relax.attrs.MaxPool2DAttrs")
+class MaxPool2DAttrs(Attrs):
+    """Attributes for nn.max_pool2d"""
+
+
+@tvm._ffi.register_object("relax.attrs.AdaptivePool2DAttrs")
+class AdaptivePool2DAttrs(Attrs):
+    """Attributes for 2d adaptive pool operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.SoftmaxAttrs")
+class SoftmaxAttrs(Attrs):
+    """Attributes for nn.softmax"""
+
+
+@tvm._ffi.register_object("relax.attrs.BatchNormAttrs")
+class BatchNormAttrs(Attrs):
+    """Attributes used in batch_norm operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.LayerNormAttrs")
+class LayerNormAttrs(Attrs):
+    """Attributes used in layer_norm operator"""
+
+
+@tvm._ffi.register_object("relax.attrs.DropoutAttrs")
+class DropoutAttrs(Attrs):
+    """Attributes for dropout operator"""
+
+
 @tvm._ffi.register_object("relax.attrs.StatisticalAttrs")
 class StatisticalAttrs(Attrs):
     """Attributes used in statistical operator"""
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 47779a6024..1f0e31428c 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -93,6 +93,7 @@ from tvm.relax.op import (
     tan,
     tanh,
     unique,
+    nn,
 )
 from tvm.relax.struct_info import StructInfo
 from tvm.relax.utils import args_converter
@@ -530,4 +531,5 @@ __all__ = [
     "tuple",
     "variance",
     "unique",
+    "nn",    
 ]
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
new file mode 100644
index 0000000000..a3ddd3e350
--- /dev/null
+++ b/src/relax/op/nn/convolution.cc
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/op/nn/convolution.cc
+ * \brief Convolution operators
+ */
+
+#include "convolution.h"
+
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* relax.nn.conv2d */
+TVM_REGISTER_NODE_TYPE(Conv2DAttrs);
+
+Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
+            Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
+            Optional<String> out_layout, DataType out_dtype) {
+  padding = GetCompletePadding2D(std::move(padding));
+  if (strides.size() == 1) {
+    strides.push_back(strides[0]);
+  }
+  if (dilation.size() == 1) {
+    dilation.push_back(dilation[0]);
+  }
+
+  CHECK_GT(groups, 0) << "The number of groups in convolution is expected to be positive. However, "
+                         "the given number of groups is "
+                      << groups;
+  CHECK_EQ(strides.size(), 2)
+      << "The input strides length is expected to be 2. However, the given strides is " << strides;
+  CHECK_EQ(dilation.size(), 2)
+      << "The input dilation length is expected to be 2. However, the given dilation is "
+      << dilation;
+  return MakeConv<Conv2DAttrs>(std::move(data), std::move(weight), std::move(strides),
+                               std::move(padding), std::move(dilation), groups, data_layout,
+                               std::move(kernel_layout), out_layout.value_or(data_layout),
+                               out_dtype, /*op_name=*/"relax.nn.conv2d");
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.conv2d").set_body_typed(conv2d);
+
+StructInfo InferStructInfoConv2d(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo data_sinfo = input_sinfo[0];
+  TensorStructInfo weight_sinfo = input_sinfo[1];
+
+  const auto* attrs = call->attrs.as<Conv2DAttrs>();
+  auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->data_layout,  //
+                                                    /*tgt_layout=*/"NCHW",          //
+                                                    /*tensor_name=*/"data");
+  auto [weight_layout, weight2OIHW] = CheckTensorLayout(call, ctx, attrs->kernel_layout,  //
+                                                        /*tgt_layout=*/"OIHW",             //
+                                                        /*tensor_name=*/"kernel");
+  auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout,  //
+                                                  /*tgt_layout=*/"NCHW",          //
+                                                  /*tensor_name=*/"output");
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  Optional<ShapeExpr> weight_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, weight_sinfo, weight_layout);
+
+  DataType out_dtype = attrs->out_dtype.is_void()
+                           ? InferBinaryArithOpOutDtype(call, ctx, data_sinfo, weight_sinfo)
+                           : attrs->out_dtype;
+  if (!data_shape.defined() || !weight_shape.defined()) {
+    return TensorStructInfo(out_dtype, out_layout.ndim());
+  }
+
+  Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
+  Array<PrimExpr> weight_OIHW_shape = weight2OIHW.ForwardShape(weight_shape.value()->values);
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  PrimExpr input_channel_data = data_NCHW_shape[1];
+  PrimExpr input_channel_kernel = weight_OIHW_shape[1];
+  if (analyzer->CanProve(input_channel_data != input_channel_kernel * attrs->groups)) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "The channel size of the data should equal to the product of input channel size of the "
+           "weight and the number of groups. However, the data channel size is "
+        << input_channel_data << " while the weight input channel size and number of groups are "
+        << input_channel_kernel << " and " << attrs->groups);
+  } else if (!analyzer->CanProveEqual(input_channel_data, input_channel_kernel * attrs->groups)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+  if (analyzer->CanProve(floormod(weight_OIHW_shape[0], attrs->groups) != 0)) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Conv2d expects the number of output channels to be divisible by the "
+                        "number of groups. However, the number of output channels is "
+                     << weight_OIHW_shape[0] << " while the number of groups is " << attrs->groups);
+  } else if (!analyzer->CanProveEqual(floormod(weight_OIHW_shape[0], attrs->groups), 0)) {
+    // Todo(relax-team): Trust the input shape at this moment, and revisit
+    // this condition with runtime shape check
+  }
+
+  PrimExpr input_h = data_NCHW_shape[2];
+  PrimExpr input_w = data_NCHW_shape[3];
+  PrimExpr kernel_h = weight_OIHW_shape[2];
+  PrimExpr kernel_w = weight_OIHW_shape[3];
+  PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
+  PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+
+  std::vector<PrimExpr> out_NCHW_shape;
+  out_NCHW_shape.resize(4);
+  out_NCHW_shape[0] = data_NCHW_shape[0];
+  out_NCHW_shape[1] = weight_OIHW_shape[0];
+
+  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1;
+  out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1);
+  out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
+
+  Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
+}
+
+TVM_REGISTER_OP("relax.nn.conv2d")
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("weight", "Tensor", "The weight tensor.")
+    .set_attrs_type<Conv2DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2d);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/nn/convolution.h b/src/relax/op/nn/convolution.h
new file mode 100644
index 0000000000..a65617b48d
--- /dev/null
+++ b/src/relax/op/nn/convolution.h
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file convolution.h
+ * \brief The functions to make Relax neural network convolution operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_CONVOLUTION_H_
+#define TVM_RELAX_OP_NN_CONVOLUTION_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include <string>
+#include <utility>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+template <typename T>
+inline Expr MakeConv(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
+                     Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
+                     String out_layout, DataType out_dtype, std::string op_name) {
+  auto attrs = make_object<T>();
+  attrs->strides = ConvertIntImmToInt64(strides);
+  attrs->padding = ConvertIntImmToInt64(padding);
+  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->groups = groups;
+  attrs->data_layout = std::move(data_layout);
+  attrs->kernel_layout = std::move(kernel_layout);
+  attrs->out_layout = std::move(out_layout);
+  attrs->out_dtype = std::move(out_dtype);
+  const Op& op = Op::Get(op_name);
+  return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
+/*! \brief 2D convolution */
+Expr conv2d(Expr data, Expr weight, Array<IntImm> strides, Array<IntImm> padding,
+            Array<IntImm> dilation, int groups, String data_layout, String kernel_layout,
+            Optional<String> out_layout, DataType out_dtype);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_NN_CONVOLUTION_H_
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
new file mode 100644
index 0000000000..66ae10fe6c
--- /dev/null
+++ b/src/relax/op/nn/nn.cc
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "nn.h"
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* relax.nn.relu */
+RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(relu, "nn.relu", /*require_float_dtype=*/false);
+
+/* relax.nn.gelu */
+RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(gelu, "nn.gelu", /*require_float_dtype=*/true);
+
+/* relax.nn.silu */
+RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(silu, "nn.silu", /*require_float_dtype=*/true);
+
+/* relax.nn.softmax */
+TVM_REGISTER_NODE_TYPE(SoftmaxAttrs);
+
+Expr softmax(Expr data, int axis) {
+  auto attrs = make_object<SoftmaxAttrs>();
+  attrs->axis = axis;
+  static const Op& op = Op::Get("relax.nn.softmax");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.softmax").set_body_typed(softmax);
+
+StructInfo InferStructInfoSoftmax(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  if (data_sinfo->IsUnknownNdim()) {
+    return data_sinfo;
+  }
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Softmax requires the input 
tensor to have float "
+                                                "dtype. However, the given 
input dtype is "
+                                             << data_sinfo->dtype);
+  }
+  const auto* attrs = call->attrs.as<SoftmaxAttrs>();
+  NormalizeAxis(call, ctx, data_sinfo->ndim, attrs->axis);
+
+  return data_sinfo;
+}
+
+TVM_REGISTER_OP("relax.nn.softmax")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attrs_type<SoftmaxAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax);
+
+bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
+                            const Array<TensorStructInfo>& input_sinfo, Array<Integer> axes) {
+  Op op = Downcast<Op>(call->op);
+  int n_input = op->arguments.size();
+
+  TensorStructInfo data_sinfo = input_sinfo[0];
+
+  std::vector<int> axes_non_neg;
+  if (!data_sinfo->IsUnknownNdim()) {
+    axes_non_neg = NormalizeAxes(call, ctx, data_sinfo->ndim, axes);
+  }
+  int n_axis = axes.size();
+  if (!data_sinfo->IsUnknownDtype() && !data_sinfo->dtype.is_float()) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << op << " requires the input data to have float dtype. However, the 
given data dtype is "
+        << data_sinfo->dtype);
+  }
+  for (int i = 1; i < n_input; ++i) {
+    if (input_sinfo[i]->dtype != data_sinfo->dtype) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << op
+                       << " requires all the input tensors to have the same 
dtype. However, the "
+                       << op->arguments[i]->name << " has dtype " << 
input_sinfo[i]->dtype
+                       << " which is other than the input data's dtype " << 
data_sinfo->dtype);
+    } else if (input_sinfo[i]->ndim != n_axis) {
+      ctx->ReportFatal(Diagnostic::Error(call)
+                       << op << " requires the input " << 
op->arguments[i]->name
+                       << " to have as many dimensions as the length of input 
axes. However, the "
+                          "given one has ndim "
+                       << input_sinfo[i]->ndim << ", which is other than the 
length of axes "
+                       << n_axis);
+    }
+  }
+
+  std::vector<Array<PrimExpr>> axis_lengths;
+  axis_lengths.reserve(n_input);
+  if (const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>()) {
+    std::vector<PrimExpr> lengths;
+    lengths.reserve(n_axis);
+    for (int d = 0; d < n_axis; ++d) {
+      lengths.push_back(data_shape->values[axes_non_neg[d]]);
+    }
+    axis_lengths.push_back(lengths);
+  }
+  for (int i = 1; i < n_input; ++i) {
+    if (const auto* shape = input_sinfo[i]->shape.as<ShapeExprNode>()) {
+      axis_lengths.push_back(shape->values);
+    }
+  }
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  for (int i = 1; i < static_cast<int>(axis_lengths.size()); ++i) {
+    for (int d = 0; d < n_axis; ++d) {
+      if (analyzer->CanProve(axis_lengths[0][d] != axis_lengths[i][d])) {
+        ctx->ReportFatal(Diagnostic::Error(call)
+                         << op
+                         << " requires the input gamma, beta, etc., to have 
size same as the "
+                            "lengths of the data on the given axes. However, 
there exists "
+                         << axis_lengths[0] << " and " << axis_lengths[i] << " 
that are unequal.");
+      } else if (!analyzer->CanProveEqual(axis_lengths[0][d], 
axis_lengths[i][d])) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+/* relax.nn.batch_norm */
+TVM_REGISTER_NODE_TYPE(BatchNormAttrs);
+
+Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,  //
+                int axis, double epsilon, bool center, bool scale) {
+  ObjectPtr<BatchNormAttrs> attrs = make_object<BatchNormAttrs>();
+  attrs->axis = axis;
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+
+  static const Op& op = Op::Get("relax.nn.batch_norm");
+  return Call(op,
+              {std::move(data), std::move(gamma), std::move(beta), std::move(moving_mean),
+               std::move(moving_var)},
+              Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.batch_norm").set_body_typed(batch_norm);
+
+StructInfo InferStructInfoBatchNorm(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<BatchNormAttrs>();
+  bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, {attrs->axis});
+
+  DataType dtype = input_sinfo[0]->dtype;
+  if (unknown_shape) {
+    return TupleStructInfo({TensorStructInfo(dtype, input_sinfo[0]->ndim),
+                            TensorStructInfo(dtype, /*ndim=*/1),
+                            TensorStructInfo(dtype, /*ndim=*/1)});
+  } else {
+    return TupleStructInfo({input_sinfo[0], input_sinfo[3], input_sinfo[4]});
+  }
+}
+
+TVM_REGISTER_OP("relax.nn.batch_norm")
+    .set_attrs_type<BatchNormAttrs>()
+    .set_num_inputs(5)
+    .add_argument("data", "Tensor", "Input to which batch_norm will be 
applied.")
+    .add_argument("gamma", "Tensor", "The gamma scale factor.")
+    .add_argument("beta", "Tensor", "The beta offset factor.")
+    .add_argument("moving_mean", "Tensor", "Running mean of input.")
+    .add_argument("moving_var", "Tensor", "Running variance of input.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchNorm);
+
+/* relax.nn.layer_norm */
+TVM_REGISTER_NODE_TYPE(LayerNormAttrs);
+
+Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center,
+                bool scale) {
+  ObjectPtr<LayerNormAttrs> attrs = make_object<LayerNormAttrs>();
+  attrs->axes = std::move(axes);
+  attrs->epsilon = epsilon;
+  attrs->center = center;
+  attrs->scale = scale;
+
+  static const Op& op = Op::Get("relax.nn.layer_norm");
+  return Call(op, {std::move(data), std::move(gamma), std::move(beta)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.layer_norm").set_body_typed(layer_norm);
+
+StructInfo InferStructInfoLayerNorm(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<LayerNormAttrs>();
+  bool unknown_shape = NormCheckDtypeAndShape(call, ctx, input_sinfo, attrs->axes);
+
+  return unknown_shape ? TensorStructInfo(input_sinfo[0]->dtype, input_sinfo[0]->ndim)
+                       : input_sinfo[0];
+}
+
+TVM_REGISTER_OP("relax.nn.layer_norm")
+    .set_attrs_type<LayerNormAttrs>()
+    .set_num_inputs(3)
+    .add_argument("data", "Tensor", "Input to which batch_norm will be 
applied.")
+    .add_argument("gamma", "Tensor", "The gamma scale factor.")
+    .add_argument("beta", "Tensor", "The beta offset factor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoLayerNorm);
+
+/* relax.nn.dropout */
+TVM_REGISTER_NODE_TYPE(DropoutAttrs);
+
+Expr dropout(Expr data, double rate) {
+  ObjectPtr<DropoutAttrs> attrs = make_object<DropoutAttrs>();
+  attrs->rate = rate;
+
+  static const Op& op = Op::Get("relax.nn.dropout");
+  return Call(op, {std::move(data)}, Attrs{attrs}, {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.dropout").set_body_typed(dropout);
+
+StructInfo InferStructInfoDropout(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  return TupleStructInfo({data_sinfo, data_sinfo});
+}
+
+TVM_REGISTER_OP("relax.nn.dropout")
+    .set_attrs_type<DropoutAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "Input to which dropout will be applied.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoDropout);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
new file mode 100644
index 0000000000..df2b978fc2
--- /dev/null
+++ b/src/relax/op/nn/nn.h
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file nn.h
+ * \brief The functions to make Relax neural network operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_NN_H_
+#define TVM_RELAX_OP_NN_NN_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Quick helper macro to
+ * - expose a make-function interface which construct the call node.
+ * - register op to the registry.
+ * \param OpName The name of operator to register.
+ * \param OpRegName The identifier of the operator in the registry.
+ * \param RequireFloatDtype A boolean indicating if the input is required to have float dtype.
+ */
+#define RELAX_REGISTER_UNARY_NN_OP_AND_IMPL(OpName, OpRegName, RequireFloatDtype) \
+  RELAX_REGISTER_UNARY_OP(OpRegName).set_attr<FInferStructInfo>(                  \
+      "FInferStructInfo", InferStructInfoUnaryArith<RequireFloatDtype>);          \
+  RELAX_UNARY_OP_INTERFACE(OpName, OpRegName);
+
+/*! \brief Rectified linear unit. */
+Expr relu(Expr data);
+
+/*! \brief Gaussian Error Linear Units function. */
+Expr gelu(Expr data);
+
+/*! \brief Sigmoid Linear Unit function. */
+Expr silu(Expr data);
+
+/*! \brief Softmax function. */
+Expr softmax(Expr data, int axis);
+
+/*! \brief Compute batch normalization. */
+Expr batch_norm(Expr data, Expr gamma, Expr beta, Expr moving_mean, Expr moving_var,  //
+                int axis, double epsilon, bool center, bool scale);
+
+/*! \brief Compute layer normalization. */
+Expr layer_norm(Expr data, Expr gamma, Expr beta, Array<Integer> axes, double epsilon, bool center,
+                bool scale);
+
+/*!
+ * \brief Applies the dropout operation to the input tensor.
+ * \param data The input data to the operator.
+ * \param rate The probability for an element to be reset to 0.
+ * \return A Tuple of two tensors.
+ * The first one is the original tensor and the second one is a
+ * mask tensor (1.0 where element not dropped, 0.0 where dropped)
+ */
+Expr dropout(Expr data, double rate);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_NN_NN_H_
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
new file mode 100644
index 0000000000..a4c1e6b17d
--- /dev/null
+++ b/src/relax/op/nn/pooling.cc
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "pooling.h"
+
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/* relax.nn.max_pool2d */
+TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
+
+Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, String layout,
+                Optional<String> out_layout) {
+  padding = GetCompletePadding2D(std::move(padding));
+  if (pool_size.size() == 1) {
+    pool_size.push_back(pool_size[0]);
+  }
+  if (strides.size() == 1) {
+    strides.push_back(strides[0]);
+  }
+  if (dilation.size() == 1) {
+    dilation.push_back(dilation[0]);
+  }
+
+  CHECK_EQ(pool_size.size(), 2)
+      << "The input pool_size length is expected to be 2. However, the given 
pool_size is "
+      << pool_size;
+  CHECK_EQ(strides.size(), 2)
+      << "The input strides length is expected to be 2. However, the given 
strides is " << strides;
+  CHECK_EQ(dilation.size(), 2)
+      << "The input dilation length is expected to be 2. However, the given 
dilation is "
+      << dilation;
+
+  auto attrs = make_object<MaxPool2DAttrs>();
+  attrs->pool_size = std::move(pool_size);
+  attrs->strides = ConvertIntImmToInt64(strides);
+  attrs->padding = ConvertIntImmToInt64(padding);
+  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->ceil_mode = ceil_mode;
+  attrs->layout = layout;
+  attrs->out_layout = out_layout.value_or(layout);
+  static const Op& op = Op::Get("relax.nn.max_pool2d");
+  return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d);
+
+StructInfo InferStructInfoMaxPool2D(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<MaxPool2DAttrs>();
+  auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,  //
+                                                    /*tgt_layout=*/"NCHW",     //
+                                                    /*tensor_name=*/"data");
+  auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout,  //
+                                                  /*tgt_layout=*/"NCHW",          //
+                                                  /*tensor_name=*/"output");
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  if (!data_shape.defined()) {
+    return TensorStructInfo(data_sinfo->dtype, out_layout.ndim());
+  }
+
+  Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
+
+  PrimExpr input_h = data_NCHW_shape[2];
+  PrimExpr input_w = data_NCHW_shape[3];
+  PrimExpr kernel_h = attrs->pool_size[0];
+  PrimExpr kernel_w = attrs->pool_size[1];
+  PrimExpr padding_h = attrs->padding[0] + attrs->padding[2];
+  PrimExpr padding_w = attrs->padding[1] + attrs->padding[3];
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  std::vector<PrimExpr> out_NCHW_shape;
+  out_NCHW_shape.resize(4);
+  out_NCHW_shape[0] = data_NCHW_shape[0];
+  out_NCHW_shape[1] = data_NCHW_shape[1];
+
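+  // Standard pooling output-size rule per spatial dimension:
+  //   out = floor((in + pad_total - dilation * (kernel - 1) - 1) / stride) + 1
+  // ceil_mode adds (stride - 1) to the numerator so the division rounds up.
+  // For example, in = 28, kernel = 3, stride = 2, pad_total = 2, dilation = 1
+  // gives floor(27 / 2) + 1 = 14.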
+  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[0] * (kernel_h - 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[1] * (kernel_w - 1) - 1;
+  if (attrs->ceil_mode) {
+    numerator_h += attrs->strides[0] - 1;
+    numerator_w += attrs->strides[1] - 1;
+  }
+  out_NCHW_shape[2] = analyzer->Simplify(floordiv(numerator_h, attrs->strides[0]) + 1);
+  out_NCHW_shape[3] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[1]) + 1);
+
+  Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.nn.max_pool2d")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attrs_type<MaxPool2DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMaxPool2D);
+
+/* relax.nn.adaptive_avg_pool2d */
+TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
+
+Expr adaptive_avg_pool2d(Expr data, Optional<Array<IntImm>> output_size, String layout,
+                         Optional<String> out_layout) {
+  ObjectPtr<AdaptivePool2DAttrs> attrs = make_object<AdaptivePool2DAttrs>();
+  attrs->layout = layout;
+  attrs->out_layout = out_layout.value_or(layout);
+  if (output_size.defined()) {
+    Array<IntImm> _output_size = output_size.value();
+    if (_output_size.size() == 1) {
+      _output_size.push_back(_output_size[0]);
+    }
+    CHECK_EQ(_output_size.size(), 2)
+        << "The output_size length is expected to be 2. However, the given 
output_size is "
+        << _output_size;
+    attrs->output_size = std::move(_output_size);
+  }
+
+  static const Op& op = Op::Get("relax.nn.adaptive_avg_pool2d");
+  return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.adaptive_avg_pool2d").set_body_typed(adaptive_avg_pool2d);
+
+StructInfo InferStructInfoAdaptiveAvgPool2D(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<AdaptivePool2DAttrs>();
+  auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,  //
+                                                    /*tgt_layout=*/"NCHW",     //
+                                                    /*tensor_name=*/"data");
+  auto [out_layout, out2NCHW] = CheckTensorLayout(call, ctx, attrs->out_layout,  //
+                                                  /*tgt_layout=*/"NCHW",          //
+                                                  /*tensor_name=*/"output");
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  if (!data_shape.defined()) {
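+    // Without a target output_size and with identical input/output layouts, adaptive
+    // average pooling keeps the input shape, so the input struct info can be reused.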
+    if (data_sinfo->shape.defined() && attrs->out_layout == attrs->layout &&
+        !attrs->output_size.defined()) {
+      return data_sinfo;
+    } else {
+      return TensorStructInfo(data_sinfo->dtype, out_layout.ndim());
+    }
+  }
+
+  Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
+  Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
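+  // Adaptive pooling keeps the batch and channel dimensions; only the two spatial
+  // dimensions are overwritten when an explicit output_size is provided.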
+  if (attrs->output_size.defined()) {
+    out_NCHW_shape.Set(2, attrs->output_size.value()[0]);
+    out_NCHW_shape.Set(3, attrs->output_size.value()[1]);
+  }
+
+  Array<PrimExpr> out_shape = out2NCHW.BackwardShape(out_NCHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype);
+}
+
+TVM_REGISTER_OP("relax.nn.adaptive_avg_pool2d")
+    .set_attrs_type<AdaptivePool2DAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAdaptiveAvgPool2D);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
new file mode 100644
index 0000000000..3c1792d21f
--- /dev/null
+++ b/src/relax/op/nn/pooling.h
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file pooling.h
+ * \brief The functions to make Relax neural network pooling operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_NN_POOLING_H_
+#define TVM_RELAX_OP_NN_POOLING_H_
+
+#include <tvm/relax/attrs/nn.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*! \brief 2D maximum pooling operator. */
+Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, String layout, Optional<String> out_layout);
+
+/*! \brief 2D adaptive average pooling operator. */
+Expr adaptive_avg_pool2d(Expr data, Optional<Array<IntImm>> output_size, String layout,
+                         Optional<String> out_layout);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_NN_POOLING_H_
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
new file mode 100644
index 0000000000..d047448309
--- /dev/null
+++ b/tests/python/relax/test_op_nn.py
@@ -0,0 +1,929 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    assert relax.op.nn.relu(x).op == Op.get("relax.nn.relu")
+    assert relax.op.nn.gelu(x).op == Op.get("relax.nn.gelu")
+    assert relax.op.nn.silu(x).op == Op.get("relax.nn.silu")
+    assert relax.op.nn.softmax(x).op == Op.get("relax.nn.softmax")
+    assert relax.op.nn.dropout(x).op == Op.get("relax.nn.dropout")
+
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    gamma = relax.Var("gamma", R.Tensor((3,), "float32"))
+    beta = relax.Var("beta", R.Tensor((3,), "float32"))
+    moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32"))
+    moving_var = relax.Var("moving_var", R.Tensor((3,), "float32"))
+    assert relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1).op == Op.get(
+        "relax.nn.batch_norm"
+    )
+    assert relax.op.nn.layer_norm(x, gamma, beta, axes=1).op == Op.get("relax.nn.layer_norm")
+
+
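+# Shared helper: normalize the call through a BlockBuilder and compare the inferred
+# struct info against the expectation.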
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+    ret = bb.normalize(call)
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_linear_unit_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+    x3 = relax.Var("x", R.Tensor((2, 3)))
+    x4 = relax.Var("x", R.Tensor())
+
+    _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(bb, relax.op.nn.silu(x1), relax.TensorStructInfo(dtype="float32", ndim=3))
+    _check_inference(bb, relax.op.nn.gelu(x2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.nn.relu(x3), relax.TensorStructInfo((2, 3), dtype=""))
+    _check_inference(bb, relax.op.nn.gelu(x4), relax.TensorStructInfo(dtype=""))
+
+
+def test_linear_unit_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    x1 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+    _check_inference(bb, relax.op.nn.silu(x0), relax.TensorStructInfo((m, n), "float32"))
+    _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((4, n), "float32"))
+
+
+def test_linear_unit_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(bb, relax.op.nn.gelu(x0), relax.TensorStructInfo(s0, "float32"))
+    _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo(s1, "float32"))
+
+
+def test_linear_unit_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+    _check_inference(bb, relax.op.nn.relu(x0), relax.TensorStructInfo((2, 3), "float64"))
+    _check_inference(bb, relax.op.nn.relu(x1), relax.TensorStructInfo((2, 3), "int8"))
+    _check_inference(bb, relax.op.nn.relu(x2), relax.TensorStructInfo((2, 3), "int64"))
+
+
+def test_linear_unit_infer_struct_info_invalid_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.gelu(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.silu(x1))
+
+
+def test_linear_unit_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.gelu(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.silu(x1))
+
+
+def test_softmax_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+    x3 = relax.Var("x", R.Tensor((2, 3)))
+    x4 = relax.Var("x", R.Tensor())
+
+    _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(
+        bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo(dtype="float32", ndim=3)
+    )
+    _check_inference(bb, relax.op.nn.softmax(x2, axis=1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.nn.softmax(x3, axis=-1), relax.TensorStructInfo((2, 3), dtype=""))
+    _check_inference(bb, relax.op.nn.softmax(x4, axis=-2), relax.TensorStructInfo(dtype=""))
+
+
+def test_softmax_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    x1 = relax.Var("x", R.Tensor((4, n), "float32"))
+
+    _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((m, n), "float32"))
+    _check_inference(bb, relax.op.nn.softmax(x1, axis=0), relax.TensorStructInfo((4, n), "float32"))
+
+
+def test_softmax_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo(s0, "float32"))
+    _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo(s1, "float32"))
+
+
+def test_softmax_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "float64"))
+
+    _check_inference(bb, relax.op.nn.softmax(x0), relax.TensorStructInfo((2, 3), "float16"))
+    _check_inference(bb, relax.op.nn.softmax(x1), relax.TensorStructInfo((2, 3), "float64"))
+
+
+def test_softmax_infer_struct_info_invalid_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.softmax(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.softmax(x1))
+
+
+def test_softmax_infer_struct_info_axis_out_of_range():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.softmax(x, axis=3))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.softmax(x, axis=-4))
+
+
+def test_softmax_wrong_with_multiple_axes():
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.softmax(x, axis=[1, 2])
+    with pytest.raises(TVMError):
+        relax.op.nn.softmax(x, axis=[-1, -2, -3])
+
+
+def test_softmax_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.softmax(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.softmax(x1))
+
+
+def test_batch_norm_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor(ndim=4))
+    x4 = relax.Var("x", R.Tensor())
+    gamma0 = relax.Var("gamma", R.Tensor((3,), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=1))
+    gamma2 = relax.Var("gamma", R.Tensor(ndim=1))
+    beta0 = relax.Var("beta", R.Tensor((3,), "float32"))
+    beta1 = relax.Var("beta", R.Tensor((3,)))
+    moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "float32"))
+    moving_mean1 = relax.Var("moving_mean", R.Tensor((3,)))
+    moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32"))
+    moving_var1 = relax.Var("moving_var", R.Tensor("float32", ndim=1))
+    moving_var2 = relax.Var("moving_var", R.Tensor(ndim=1))
+
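+    # batch_norm returns a 3-tuple: the normalized tensor plus two channel-length
+    # tensors (the moving mean and the moving variance).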
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 3, 28, 28), "float32"),
+                relax.TensorStructInfo((3,), "float32"),
+                relax.TensorStructInfo((3,), "float32"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=-3),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 3, 28, 28), "float32"),
+                relax.TensorStructInfo((3,), "float32"),
+                relax.TensorStructInfo((3,), "float32"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x1, gamma0, beta0, moving_mean0, moving_var0, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=4),
+                relax.TensorStructInfo((3,), "float32"),
+                relax.TensorStructInfo((3,), "float32"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 3, 28, 28), "float32"),
+                relax.TensorStructInfo((3,), "float32"),
+                relax.TensorStructInfo((3,), "float32"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 3, 28, 28), "float32"),
+                relax.TensorStructInfo((3,), "float32"),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x1, gamma1, beta0, moving_mean0, moving_var1, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=4),
+                relax.TensorStructInfo((3,), "float32"),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x2, gamma1, beta0, moving_mean0, moving_var1, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32"),
+                relax.TensorStructInfo((3,), "float32"),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x3, gamma2, beta1, moving_mean1, moving_var2, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(ndim=4, dtype=""),
+                relax.TensorStructInfo((3,), dtype=""),
+                relax.TensorStructInfo(dtype="", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x4, gamma2, beta1, moving_mean1, moving_var2, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype=""),
+                relax.TensorStructInfo((3,), dtype=""),
+                relax.TensorStructInfo(dtype="", ndim=1),
+            ]
+        ),
+    )
+
+
+def test_batch_norm_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c0 = tir.Var("c", "int64")
+    c1 = tir.Var("c", "int64")
+    h = tir.Var("h", "int64")
+    w = tir.Var("w", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c0, h, w), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c1, h, w), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+    gamma0 = relax.Var("gamma", R.Tensor((c0,), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor((c1,), "float32"))
+    gamma2 = relax.Var("gamma", R.Tensor("float32", ndim=1))
+    beta = relax.Var("beta", R.Tensor((c0,), "float32"))
+    moving_mean = relax.Var("moving_mean", R.Tensor((c0,), "float32"))
+    moving_var0 = relax.Var("moving_var", R.Tensor((c0,), "float32"))
+    moving_var1 = relax.Var("moving_var", R.Tensor((c1,), "float32"))
+    moving_var2 = relax.Var("moving_var", R.Tensor("float32", ndim=1))
+
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var0, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((n, c0, h, w), "float32"),
+                relax.TensorStructInfo((c0,), "float32"),
+                relax.TensorStructInfo((c0,), "float32"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=4),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x2, gamma0, beta, moving_mean, moving_var0, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=4),
+                relax.TensorStructInfo((c0,), "float32"),
+                relax.TensorStructInfo((c0,), "float32"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=4),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=4),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma2, beta, moving_mean, moving_var0, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((n, c0, h, w), "float32"),
+                relax.TensorStructInfo((c0,), "float32"),
+                relax.TensorStructInfo((c0,), "float32"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var2, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((n, c0, h, w), "float32"),
+                relax.TensorStructInfo((c0,), "float32"),
+                relax.TensorStructInfo(dtype="float32", ndim=1),
+            ]
+        ),
+    )
+
+
+def test_batch_norm_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s1", relax.ShapeStructInfo())
+    s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=1))
+    s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=1))
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32"))
+    beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32"))
+    moving_mean = relax.Var("moving_mean", relax.TensorStructInfo(s2, 
"float32"))
+    moving_var = relax.Var("moving_var", relax.TensorStructInfo(s3, "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x0, gamma, beta, moving_mean, moving_var, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(s0, "float32"),
+                relax.TensorStructInfo(s2, "float32"),
+                relax.TensorStructInfo(s3, "float32"),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x1, gamma, beta, moving_mean, moving_var, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(s1, "float32"),
+                relax.TensorStructInfo(s2, "float32"),
+                relax.TensorStructInfo(s3, "float32"),
+            ]
+        ),
+    )
+
+
+def test_batch_norm_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
+    gamma = relax.Var("gamma", R.Tensor((3,), "float16"))
+    beta = relax.Var("beta", R.Tensor((3,), "float16"))
+    moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float16"))
+    moving_var = relax.Var("moving_var", R.Tensor((3,), "float16"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo((2, 3, 28, 28), "float16"),
+                relax.TensorStructInfo((3,), "float16"),
+                relax.TensorStructInfo((3,), "float16"),
+            ]
+        ),
+    )
+
+
+def test_batch_norm_infer_struct_info_invalid_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
+    gamma0 = relax.Var("gamma", R.Tensor((3,), "int8"))
+    beta0 = relax.Var("beta", R.Tensor((3,), "int8"))
+    moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "int8"))
+    moving_var0 = relax.Var("moving_var", R.Tensor((3,), "int8"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32"))
+    gamma1 = relax.Var("gamma", R.Tensor((3,), "int32"))
+    beta1 = relax.Var("beta", R.Tensor((3,), "int32"))
+    moving_mean1 = relax.Var("moving_mean", R.Tensor((3,), "int32"))
+    moving_var1 = relax.Var("moving_var", R.Tensor((3,), "int32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var0, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x1, gamma1, beta1, moving_mean1, moving_var1, axis=1))
+
+
+def test_batch_norm_infer_struct_info_axis_out_of_range():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    gamma = relax.Var("gamma", R.Tensor((3,), "float32"))
+    beta = relax.Var("beta", R.Tensor((3,), "float32"))
+    moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32"))
+    moving_var = relax.Var("moving_var", R.Tensor((3,), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=4))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=-5))
+
+
+def test_batch_norm_infer_struct_info_dtype_mismatch():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
+    gamma0 = relax.Var("gamma", R.Tensor((3,), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor((3,)))
+    beta = relax.Var("beta", R.Tensor((3,), "float32"))
+    moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32"))
+    moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32"))
+    moving_var1 = relax.Var("moving_var", R.Tensor((3,), "float16"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var0, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var0, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta, moving_mean, moving_var1, axis=1))
+
+
+def test_batch_norm_infer_struct_info_ndim_mismatch():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    gamma0 = relax.Var("gamma", R.Tensor((3,), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor((3, 1), "float32"))
+    beta = relax.Var("beta", R.Tensor((3,), "float32"))
+    moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32"))
+    moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32"))
+    moving_var1 = relax.Var("moving_var", R.Tensor((1, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x, gamma1, beta, moving_mean, moving_var0, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x, gamma0, beta, moving_mean, moving_var1, axis=1))
+
+
+def test_batch_norm_infer_struct_info_shape_mismatch():
+    bb = relax.BlockBuilder()
+    c = tir.Var("c", "int64")
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, c, 28, 28), "float32"))
+    gamma0 = relax.Var("gamma", R.Tensor((3,), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor((4,), "float32"))
+    gamma2 = relax.Var("gamma", R.Tensor((c + 2,), "float32"))
+    beta0 = relax.Var("beta", R.Tensor((3,), "float32"))
+    beta1 = relax.Var("beta", R.Tensor((c,), "float32"))
+    moving_mean0 = relax.Var("moving_mean", R.Tensor((3,), "float32"))
+    moving_mean1 = relax.Var("moving_mean", R.Tensor((c,), "float32"))
+    moving_var0 = relax.Var("moving_var", R.Tensor((3,), "float32"))
+    moving_var1 = relax.Var("moving_var", R.Tensor((4,), "float32"))
+    moving_var2 = relax.Var("moving_var", R.Tensor((c,), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta0, moving_mean0, moving_var0, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x0, gamma0, beta0, moving_mean0, moving_var1, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x1, gamma2, beta1, moving_mean1, moving_var2, axis=1))
+
+
+def test_batch_norm_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+    gamma0 = relax.Var("gamma", R.Tensor((3,), "float32"))
+    gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((3,), 
"float32")))
+    beta = relax.Var("beta", R.Tensor((3,), "float32"))
+    moving_mean = relax.Var("moving_mean", R.Tensor((3,), "float32"))
+    moving_var = relax.Var("moving_var", R.Tensor((3,), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x1, gamma0, beta, moving_mean, moving_var, axis=1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_norm(x0, gamma1, beta, moving_mean, moving_var, axis=1))
+
+
+def test_layer_norm_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+    gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor("float32", ndim=2))
+    gamma2 = relax.Var("gamma", R.Tensor((4, 5)))
+    beta0 = relax.Var("beta", R.Tensor((4, 5), "float32"))
+    beta1 = relax.Var("beta", R.Tensor((4, 5)))
+
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]),
+        relax.TensorStructInfo((2, 3, 4, 5), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, 3]),
+        relax.TensorStructInfo((2, 3, 4, 5), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x1, gamma0, beta0, axes=[-2, -1]),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x2, gamma0, beta0, axes=[-2, -1]),
+        relax.TensorStructInfo(dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x0, gamma1, beta0, axes=[-2, -1]),
+        relax.TensorStructInfo((2, 3, 4, 5), dtype="float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x3, gamma2, beta1, axes=[-2, -1]),
+        relax.TensorStructInfo((2, 3, 4, 5), dtype=""),
+    )
+
+
+def test_layer_norm_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c0 = tir.Var("c", "int64")
+    c1 = tir.Var("c", "int64")
+    x0 = relax.Var("x", R.Tensor((n, a, b, c0), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, a, b, c1), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+    gamma0 = relax.Var("gamma", R.Tensor((b, c0), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor((b, c1), "float32"))
+    beta = relax.Var("beta", R.Tensor((b, c0), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x0, gamma0, beta, axes=[-2, -1]),
+        relax.TensorStructInfo((n, a, b, c0), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1]),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1]),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x2, gamma0, beta, axes=[-2, -1]),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x2, gamma1, beta, axes=[-2, -1]),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+
+
+def test_layer_norm_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s1", relax.ShapeStructInfo())
+    s2 = relax.Var("s2", relax.ShapeStructInfo(ndim=2))
+    s3 = relax.Var("s3", relax.ShapeStructInfo(ndim=2))
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    gamma = relax.Var("gamma", relax.TensorStructInfo(s2, "float32"))
+    beta = relax.Var("beta", relax.TensorStructInfo(s3, "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x0, gamma, beta, axes=[2, 3]),
+        relax.TensorStructInfo(s0, "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x1, gamma, beta, axes=[2, 3]),
+        relax.TensorStructInfo(s1, "float32"),
+    )
+
+
+def test_layer_norm_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float16"))
+    gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float16"))
+    beta0 = relax.Var("beta", R.Tensor((4, 5), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float64"))
+    gamma1 = relax.Var("gamma", R.Tensor((4, 5), "float64"))
+    beta1 = relax.Var("beta", R.Tensor((4, 5), "float64"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]),
+        relax.TensorStructInfo((2, 3, 4, 5), "float16"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1]),
+        relax.TensorStructInfo((2, 3, 4, 5), "float64"),
+    )
+
+
+def test_layer_norm_infer_struct_info_invalid_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int8"))
+    gamma0 = relax.Var("gamma", R.Tensor((4, 5), "int8"))
+    beta0 = relax.Var("beta", R.Tensor((4, 5), "int8"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4, 5), "int32"))
+    gamma1 = relax.Var("gamma", R.Tensor((4, 5), "int32"))
+    beta1 = relax.Var("beta", R.Tensor((4, 5), "int32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1]))
+
+
+def test_layer_norm_infer_struct_info_axis_out_of_range_and_repetitive():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    gamma = relax.Var("gamma", R.Tensor((4, 5), "float32"))
+    beta = relax.Var("beta", R.Tensor((4, 5), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, 4]))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x, gamma, beta, axes=[3, -1]))
+
+
+def test_layer_norm_infer_struct_info_dtype_mismatch():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor((4, 5), "int8"))
+    beta0 = relax.Var("beta", R.Tensor((4, 5), "float32"))
+    beta1 = relax.Var("beta", R.Tensor((4, 5)))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x, gamma1, beta0, axes=[-2, -1]))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1]))
+
+
+def test_layer_norm_infer_struct_info_ndim_mismatch():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor((4,), "float32"))
+    beta0 = relax.Var("beta", R.Tensor((4, 5), "float32"))
+    beta1 = relax.Var("beta", R.Tensor((3, 4, 5), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x, gamma1, beta0, axes=[-2, -1]))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x, gamma0, beta1, axes=[-2, -1]))
+
+
+def test_layer_norm_infer_struct_info_shape_mismatch():
+    bb = relax.BlockBuilder()
+    c0 = tir.Var("c", "int64")
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 4, c0), "float32"))
+    gamma0 = relax.Var("gamma", R.Tensor((4, 6), "float32"))
+    gamma1 = relax.Var("gamma", R.Tensor((4, c0), "float32"))
+    beta0 = relax.Var("beta", R.Tensor((4, 5), "float32"))
+    beta1 = relax.Var("beta", R.Tensor((4, c0 - 2), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x0, gamma0, beta0, axes=[-2, -1]))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x1, gamma1, beta1, axes=[-2, -1]))
+
+
+def test_layer_norm_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 4, 5)))
+    gamma0 = relax.Var("gamma", R.Tensor((4, 5), "float32"))
+    gamma1 = relax.Var("gamma", relax.FuncStructInfo([], R.Tensor((4, 5), 
"float32")))
+    beta = relax.Var("beta", R.Tensor((4, 5), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x1, gamma0, beta, axes=[-2, -1]))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.layer_norm(x0, gamma1, beta, axes=[-2, -1]))
+
+
+def test_dropout_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+    x3 = relax.Var("x", R.Tensor((2, 3)))
+    x4 = relax.Var("x", R.Tensor())
+
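+    # dropout returns a 2-tuple: the output tensor and a mask tensor of the same shape
+    # and dtype (1.0 where an element is kept, 0.0 where it is dropped).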
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x0),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo((2, 3), "float32"), 
relax.TensorStructInfo((2, 3), "float32")]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x1),
+        relax.TupleStructInfo(
+            [
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+                relax.TensorStructInfo(dtype="float32", ndim=3),
+            ]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x2),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo(dtype="float32"), 
relax.TensorStructInfo(dtype="float32")]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x3),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo((2, 3), dtype=""), 
relax.TensorStructInfo((2, 3), dtype="")]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x4),
+        relax.TupleStructInfo([relax.TensorStructInfo(dtype=""), 
relax.TensorStructInfo(dtype="")]),
+    )
+
+
+def test_dropout_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x = relax.Var("x", R.Tensor((m, n), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo((m, n), "float32"), 
relax.TensorStructInfo((m, n), "float32")]
+        ),
+    )
+
+
+def test_dropout_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s1 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x0),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo(s0, "float32"), relax.TensorStructInfo(s0, 
"float32")]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x1),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo(s1, "float32"), relax.TensorStructInfo(s1, 
"float32")]
+        ),
+    )
+
+
+def test_dropout_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float64"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3), "int64"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x0),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo((2, 3), "float64"), 
relax.TensorStructInfo((2, 3), "float64")]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x1),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo((2, 3), "int8"), 
relax.TensorStructInfo((2, 3), "int8")]
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.dropout(x2),
+        relax.TupleStructInfo(
+            [relax.TensorStructInfo((2, 3), "int64"), 
relax.TensorStructInfo((2, 3), "int64")]
+        ),
+    )
+
+
+def test_dropout_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.dropout(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.dropout(x1))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
new file mode 100644
index 0000000000..6533d43420
--- /dev/null
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -0,0 +1,429 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+    assert relax.op.nn.conv2d(x, w).op == Op.get("relax.nn.conv2d")
+
+
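+# Normalize the call in a BlockBuilder and check the inferred struct info.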
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+    ret = bb.normalize(call)
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_conv2d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 28, 28, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor())
+    x5 = relax.Var("x", R.Tensor((2, 4, 28, 28, 16), "float32"))
+    w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+    w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float32"))
+    w2 = relax.Var("w", R.Tensor("float32", ndim=4))
+    w3 = relax.Var("w", R.Tensor("float32"))
+    w4 = relax.Var("w", R.Tensor((48, 4, 3, 3, 16), "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, out_dtype="float16"),
+        relax.TensorStructInfo((2, 4, 26, 26), "float16"),
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x0, w0, padding=1), relax.TensorStructInfo((2, 4, 28, 28), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, padding=[1, 2]),
+        relax.TensorStructInfo((2, 4, 28, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, padding=[1, 2, 3, 4]),
+        relax.TensorStructInfo((2, 4, 30, 32), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, strides=2),
+        relax.TensorStructInfo((2, 4, 13, 13), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, strides=(2, 3)),
+        relax.TensorStructInfo((2, 4, 13, 9), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, dilation=2),
+        relax.TensorStructInfo((2, 4, 24, 24), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, dilation=(2, 1)),
+        relax.TensorStructInfo((2, 4, 24, 26), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x1, w0, data_layout="NHWC"),
+        relax.TensorStructInfo((2, 26, 26, 4), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, out_layout="NHWC"),
+        relax.TensorStructInfo((2, 26, 26, 4), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w1, kernel_layout="IOHW"),
+        relax.TensorStructInfo((2, 4, 26, 26), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(
+            x5, w4, data_layout="NCHW16c", kernel_layout="OIHW16i", 
out_layout="NHWC16c"
+        ),
+        relax.TensorStructInfo((2, 26, 26, 3, 16), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x2, w0), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x3, w0), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x0, w2), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x0, w3), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(bb, relax.op.nn.conv2d(x4, w0), relax.TensorStructInfo(dtype="", ndim=4))
+
+
+def test_conv2d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    c16 = tir.Var("c16", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    ki = tir.Var("ki", "int64")
+    ko = tir.Var("ko", "int64")
+    kh = tir.Var("kh", "int64")
+    kw = tir.Var("kw", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32"))
+    w0 = relax.Var("w", R.Tensor((ko, ki, kh, kw), "float32"))
+    w1 = relax.Var("w", R.Tensor((ko, c, kh, kw), "float32"))
+    w2 = relax.Var("w", R.Tensor((ko, c, kh, kw, c16), "float32"))
+
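+    # With the default unit strides/dilation and no padding, the spatial output dims
+    # reduce to (ih - kh + 1, iw - kw + 1) and stay symbolic.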
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0),
+        relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w1),
+        relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(
+            x1, w2, data_layout="NCHW16c", kernel_layout="OIHW16i", 
out_layout="NCHW"
+        ),
+        relax.TensorStructInfo((n, ko, ih + 1 - kh, iw + 1 - kw), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, strides=(2, 2), padding=(1, 1), 
dilation=(2, 2)),
+        relax.TensorStructInfo(
+            (n, ko, tvm.tir.floordiv(ih + 3, 2) + 1 - kh, tvm.tir.floordiv(iw 
+ 3, 2) + 1 - kw),
+            "float32",
+        ),
+    )
+
+
+def test_conv2d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s2 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s3 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s3, "float32"))
+    w = relax.Var("w", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(bb, relax.op.nn.conv2d(x0, w), relax.TensorStructInfo(dtype="float32", ndim=4))
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x1, w, data_layout="NCHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=5),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w, out_layout="NCHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=5),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x2, w),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+
+
+def test_conv2d_infer_struct_info_groups():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 8, 28, 28, 16), "float32"))
+    w0 = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32"))
+    w1 = relax.Var("w", R.Tensor((48, 2, 3, 3, 8), "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.conv2d(x0, w0, groups=8), relax.TensorStructInfo((2, 48, 26, 26), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w1, kernel_layout="OIHW8i", groups=8),
+        relax.TensorStructInfo((2, 48, 26, 26), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x1, w0, data_layout="NCHW16c", groups=8),
+        relax.TensorStructInfo((2, 3, 26, 26, 16), "float32"),
+    )
+
+
+def test_conv2d_infer_struct_info_symbolic_groups():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    ic = tir.Var("c", "int64")
+    oc = tir.Var("oc", "int64")
+    x = relax.Var("x", R.Tensor((n, ic * 4, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((oc * 4, ic, 3, 3), "float32"))
+    w1 = relax.Var("w", R.Tensor((oc, ic, 3, 3), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x, w0, groups=4),
+        relax.TensorStructInfo((n, oc * 4, 26, 26), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x, w1, groups=4), relax.TensorStructInfo((n, oc, 26, 26), "float32")
+    )
+
+
+def test_conv2d_infer_struct_info_input_channel_group_incompatible():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    ic = tir.Var("c", "int64")
+    oc = tir.Var("oc", "int64")
+    x0 = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((48, 20, 3, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32"))
+    w1 = relax.Var("w", R.Tensor((oc, ic - 1, 3, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6))
+
+
+def test_conv2d_infer_struct_info_output_channel_group_incompatible():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    ic = tir.Var("c", "int64")
+    oc = tir.Var("oc", "int64")
+    x0 = relax.Var("x", R.Tensor((2, 120, 28, 28), "float32"))
+    w0 = relax.Var("w", R.Tensor((128, 20, 3, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, ic * 6, 28, 28), "float32"))
+    w1 = relax.Var("w", R.Tensor((oc * 6 + 4, ic * 6, 3, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w0, groups=6))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x1, w1, groups=6))
+
+
+def test_conv2d_non_positive_group():
+    x = relax.Var("x", R.Tensor((2, 128, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((48, 16, 3, 3), "float32"))
+
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d(x, w, groups=0)
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d(x, w, groups=-2)
+
+
+def test_conv2d_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
+    w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float64"))
+    w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float64"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
+    w2 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8"))
+    x3 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int32"))
+    w3 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int32"))
+
+    _check_inference(
+        bb, relax.op.nn.conv2d(x0, w0), relax.TensorStructInfo((2, 4, 26, 26), "float16")
+    )
+    _check_inference(
+        bb, relax.op.nn.conv2d(x1, w1), relax.TensorStructInfo((2, 4, 26, 26), "float64")
+    )
+    _check_inference(bb, relax.op.nn.conv2d(x2, w2), relax.TensorStructInfo((2, 4, 26, 26), "int8"))
+    _check_inference(
+        bb, relax.op.nn.conv2d(x3, w3), relax.TensorStructInfo((2, 4, 26, 26), "int32")
+    )
+
+
+def test_conv2d_infer_struct_info_mixed_precision():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
+    w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
+    w1 = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 28, 28)))
+    w2 = relax.Var("w", R.Tensor((4, 3, 3, 3)))
+
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x0, w0, out_dtype="float32"),
+        relax.TensorStructInfo((2, 4, 26, 26), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x1, w1, out_dtype="int32"),
+        relax.TensorStructInfo((2, 4, 26, 26), "int32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.conv2d(x2, w2, out_dtype="float32"),
+        relax.TensorStructInfo((2, 4, 26, 26), "float32"),
+    )
+
+
+def test_conv2d_unequal_input_channel():
+    bb = relax.BlockBuilder()
+    ic = tir.Var("ic", "int64")
+    x0 = relax.Var("x", R.Tensor([2, 3, 28, 28], "float32"))
+    w0 = relax.Var("w", R.Tensor([3, 4, 3, 3], "float32"))
+    x1 = relax.Var("x", R.Tensor([2, ic, 28, 28], "float32"))
+    w1 = relax.Var("w", R.Tensor([4, ic + 2, 3, 3], "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x1, w1))
+
+
+def test_conv2d_stride_padding_dilation_int64():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+    conv2d = relax.op.nn.conv2d(x, w, strides=(1, 1), padding=(1, 1), dilation=(1, 1))
+
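+    # The strides/padding/dilation attributes are expected to be stored as int64 values.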
+    assert conv2d.attrs.strides[0].dtype == "int64"
+    assert conv2d.attrs.strides[1].dtype == "int64"
+    assert conv2d.attrs.padding[0].dtype == "int64"
+    assert conv2d.attrs.padding[1].dtype == "int64"
+    assert conv2d.attrs.padding[2].dtype == "int64"
+    assert conv2d.attrs.padding[3].dtype == "int64"
+    assert conv2d.attrs.dilation[0].dtype == "int64"
+    assert conv2d.attrs.dilation[1].dtype == "int64"
+
+
+def test_conv2d_wrong_strides_padding_dilation_length():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d(x, w, strides=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d(x, w, padding=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.conv2d(x, w, dilation=(1, 2, 3))
+
+
+def test_conv2d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x, w, data_layout="OIHW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x, w, kernel_layout="NHWC"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x, w, out_layout="OHWI"))
+
+
+def test_conv2d_dtype_mismatch():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    w = relax.Var("w", R.Tensor((4, 3, 3, 3), "int8"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x, w))
+
+
+def test_conv2d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+    w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+    w1 = relax.Var("w", R.Tensor((4, 3, 6, 3, 3), "float32"))
+    w2 = relax.Var("w", R.Tensor("float32", ndim=6))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w1, data_layout="NCHW16c"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w2))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x1, w0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x2, w0))
+
+
+def test_conv2d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    x1 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+    w0 = relax.Var("w", R.Tensor((4, 3, 3, 3), "float32"))
+    w1 = relax.Var("w", relax.FuncStructInfo([], R.Tensor((4, 3, 3, 3), 
"float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x0, w1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.conv2d(x1, w0))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_op_nn_pooling.py b/tests/python/relax/test_op_nn_pooling.py
new file mode 100644
index 0000000000..0eec5de21c
--- /dev/null
+++ b/tests/python/relax/test_op_nn_pooling.py
@@ -0,0 +1,429 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    assert relax.op.nn.max_pool2d(x).op == Op.get("relax.nn.max_pool2d")
+    assert relax.op.nn.adaptive_avg_pool2d(x).op == Op.get("relax.nn.adaptive_avg_pool2d")
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+    ret = bb.normalize(call)
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_max_pool2d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor(ndim=4))
+    x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x0, pool_size=3),
+        relax.TensorStructInfo((2, 3, 30, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x0, pool_size=(5, 3)),
+        relax.TensorStructInfo((2, 3, 28, 30), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.max_pool2d(x0, padding=1), relax.TensorStructInfo((2, 3, 34, 34), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x0, padding=[1, 2]),
+        relax.TensorStructInfo((2, 3, 34, 36), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x0, strides=2),
+        relax.TensorStructInfo((2, 3, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x0, dilation=2),
+        relax.TensorStructInfo((2, 3, 32, 32), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x1, layout="NHWC"),
+        relax.TensorStructInfo((2, 32, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x0, out_layout="NHWC"),
+        relax.TensorStructInfo((2, 32, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x6, layout="NCHW16c", out_layout="NHWC16c"),
+        relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb, relax.op.nn.max_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(bb, relax.op.nn.max_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4))
+    _check_inference(bb, relax.op.nn.max_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4))
+
+
+def test_max_pool2d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    c16 = tir.Var("c16", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(
+            x0, pool_size=(3, 3), strides=(3, 3), padding=(2, 2), dilation=(2, 2)
+        ),
+        relax.TensorStructInfo(
+            (
+                n,
+                c,
+                tvm.tir.floordiv(ih - 1, 3) + 1,
+                tvm.tir.floordiv(iw - 1, 3) + 1,
+            ),
+            "float32",
+        ),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x1, layout="NCHW16c", out_layout="NHWC"),
+        relax.TensorStructInfo((n, ih, iw, c * 16), "float32"),
+    )
+
+
+def test_max_pool2d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x1, layout="NCHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=5),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x2),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+
+
+def test_max_pool2d_infer_struct_info_ceil_mode():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x, pool_size=3, strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 16, 16), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(x, pool_size=(5, 3), strides=2, ceil_mode=True),
+        relax.TensorStructInfo((2, 3, 15, 16), "float32"),
+    )
+
+
+def test_max_pool2d_infer_struct_info_ceil_mode_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    x = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+
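+    # With pool_size=3, strides=2, padding=1, dilation=2 and ceil_mode=True, the effective
+    # window size is 5, so the output spatial extent works out to floordiv(ih, 2) and floordiv(iw, 2).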
+    _check_inference(
+        bb,
+        relax.op.nn.max_pool2d(
+            x, pool_size=(3, 3), strides=(2, 2), padding=(1, 1), dilation=(2, 2), ceil_mode=True
+        ),
+        relax.TensorStructInfo((n, c, tvm.tir.floordiv(ih, 2), tvm.tir.floordiv(iw, 2)), "float32"),
+    )
+
+
+def test_max_pool2d_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64"))
+    _check_inference(
+        bb, relax.op.nn.max_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16")
+    )
+    _check_inference(bb, relax.op.nn.max_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8"))
+    _check_inference(
+        bb, relax.op.nn.max_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64")
+    )
+
+
+def test_max_pool2d_stride_padding_dilation_int64():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    max_pool2d = relax.op.nn.max_pool2d(x, (3, 3), strides=(1, 1), padding=(1, 1), dilation=(1, 1))
+
+    assert max_pool2d.attrs.strides[0].dtype == "int64"
+    assert max_pool2d.attrs.strides[1].dtype == "int64"
+    assert max_pool2d.attrs.padding[0].dtype == "int64"
+    assert max_pool2d.attrs.padding[1].dtype == "int64"
+    assert max_pool2d.attrs.padding[2].dtype == "int64"
+    assert max_pool2d.attrs.padding[3].dtype == "int64"
+    assert max_pool2d.attrs.dilation[0].dtype == "int64"
+    assert max_pool2d.attrs.dilation[1].dtype == "int64"
+
+
+def test_max_pool2d_wrong_pool_size_strides_padding_dilation_length():
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.max_pool2d(x, pool_size=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.max_pool2d(x, strides=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.max_pool2d(x, padding=(1, 2, 3))
+    with pytest.raises(TVMError):
+        relax.op.nn.max_pool2d(x, dilation=(1, 2, 3))
+
+
+def test_max_pool2d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool2d(x, layout="OIHW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool2d(x, out_layout="OHWI"))
+
+
+def test_max_pool2d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool2d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool2d(x1))
+
+
+def test_max_pool2d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), 
"float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool2d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.max_pool2d(x1))
+
+
+def test_adaptive_avg_pool2d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x3 = relax.Var("x", R.Tensor("float32"))
+    x4 = relax.Var("x", R.Tensor(ndim=4))
+    x5 = relax.Var("x", R.Tensor())
+    x6 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+
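+    # When output_size is omitted, adaptive_avg_pool2d keeps the input spatial size.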
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x0, output_size=30),
+        relax.TensorStructInfo((2, 3, 30, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x0, output_size=(28, 30)),
+        relax.TensorStructInfo((2, 3, 28, 30), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x1, layout="NHWC"),
+        relax.TensorStructInfo((2, 32, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NHWC"),
+        relax.TensorStructInfo((2, 32, 32, 3), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x6, layout="NCHW16c", 
out_layout="NHWC16c"),
+        relax.TensorStructInfo((2, 32, 32, 4, 16), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x3), relax.TensorStructInfo(dtype="float32", ndim=4)
+    )
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x4), relax.TensorStructInfo(dtype="", ndim=4)
+    )
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x5), relax.TensorStructInfo(dtype="", ndim=4)
+    )
+
+
+def test_adaptive_avg_pool2d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n = tir.Var("n", "int64")
+    c = tir.Var("c", "int64")
+    c16 = tir.Var("c16", "int64")
+    ih = tir.Var("ih", "int64")
+    iw = tir.Var("iw", "int64")
+    x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, ih, iw, c16), "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((n, c, ih, iw), "float32")
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x0, output_size=256),
+        relax.TensorStructInfo((n, c, 256, 256), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x0, output_size=(256, 128)),
+        relax.TensorStructInfo((n, c, 256, 128), "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c", 
out_layout="NHWC"),
+        relax.TensorStructInfo((n, ih, iw, c * 16), "float32"),
+    )
+
+
+def test_adaptive_avg_pool2d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo(s0, "float32"))
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x0, output_size=32),
+        relax.TensorStructInfo(dtype="float32", ndim=4),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x1, layout="NCHW16c"),
+        relax.TensorStructInfo(s1, "float32"),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x0, out_layout="NCHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=5),
+    )
+    _check_inference(
+        bb,
+        relax.op.nn.adaptive_avg_pool2d(x2, out_layout="NCHW16c"),
+        relax.TensorStructInfo(dtype="float32", ndim=5),
+    )
+
+
+def test_adaptive_avg_pool2d_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "int64"))
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x0), relax.TensorStructInfo((2, 3, 32, 32), "float16")
+    )
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x1), relax.TensorStructInfo((2, 3, 32, 32), "int8")
+    )
+    _check_inference(
+        bb, relax.op.nn.adaptive_avg_pool2d(x2), relax.TensorStructInfo((2, 3, 32, 32), "int64")
+    )
+
+
+def test_adaptive_avg_pool2d_wrong_output_size_ndim():
+    x = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    with pytest.raises(TVMError):
+        relax.op.nn.adaptive_avg_pool2d(x, (32, 32, 32))
+
+
+def test_adaptive_avg_pool2d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.adaptive_avg_pool2d(x, layout="OIHW"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.adaptive_avg_pool2d(x, out_layout="OHWI"))
+
+
+def test_adaptive_avg_pool2d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 28, 28, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=3))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1))
+
+
+def test_adaptive_avg_pool2d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), 
"float32")))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.adaptive_avg_pool2d(x0))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.adaptive_avg_pool2d(x1))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_nn.py b/tests/python/relax/test_tvmscript_parser_op_nn.py
new file mode 100644
index 0000000000..4e52bccb86
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_nn.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+    parsed: Union[relax.Function, IRModule],
+    expect: Optional[Union[relax.Function, IRModule]],
+):
+    test = parsed.script(show_meta=True)
+    roundtrip_mod = tvm.script.from_source(test)
+    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+    if expect:
+        tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_conv2d():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 228, 228), "float32"), w: R.Tensor((16, 3, 5, 5), 
"float32")
+    ) -> R.Tensor((2, 16, 224, 224), "float16"):
+        gv: R.Tensor((2, 16, 224, 224), "float16") = R.nn.conv2d(x, w, 
out_dtype="float16")
+        return gv
+
+    x = relax.Var("x", R.Tensor([2, 3, 228, 228], "float32"))
+    w = relax.Var("w", R.Tensor([16, 3, 5, 5], "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, w]):
+        gv = bb.emit(relax.op.nn.conv2d(x, w, out_dtype="float16"))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_max_pool2d():
+    @R.function
+    def foo(
+        x: R.Tensor((1, 1, 32, 32), dtype="float32")
+    ) -> R.Tensor((1, 1, 30, 30), dtype="float32"):
+        gv: R.Tensor((1, 1, 30, 30), dtype="float32") = R.nn.max_pool2d(x, 
pool_size=(3,))
+        return gv
+
+    x = relax.Var("x", R.Tensor([1, 1, 32, 32], "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.nn.max_pool2d(x, pool_size=(3,)))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_adaptive_avg_pool2d():
+    @R.function
+    def foo(x: R.Tensor((2, 64, 8, 9), "float32")) -> R.Tensor((2, 64, 7, 7), 
"float32"):
+        gv: R.Tensor((2, 64, 7, 7), "float32") = R.nn.adaptive_avg_pool2d(x, 
output_size=(7, 7))
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 64, 8, 9), dtype="float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.nn.adaptive_avg_pool2d(x, output_size=(7, 7)))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_gelu():
+    @R.function
+    def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+        gv: R.Tensor((2, 3), "float32") = R.nn.gelu(x)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.nn.gelu(x))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_softmax():
+    @R.function
+    def foo(x: R.Tensor((2, 3), "float32")) -> R.Tensor((2, 3), "float32"):
+        gv: R.Tensor((2, 3), "float32") = R.nn.softmax(x)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.nn.softmax(x))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_batch_norm():
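+    # batch_norm returns a 3-tuple: the normalized tensor plus tensors shaped like moving_mean and moving_var.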
+    @R.function
+    def foo(
+        x: R.Tensor((2, 4, 3, 3), dtype="float32"),
+        gamma: R.Tensor((4,), dtype="float32"),
+        beta: R.Tensor((4,), dtype="float32"),
+        moving_mean: R.Tensor((4,), dtype="float32"),
+        moving_var: R.Tensor((4,), dtype="float32"),
+    ) -> R.Tuple(
+        R.Tensor((2, 4, 3, 3), dtype="float32"),
+        R.Tensor((4,), dtype="float32"),
+        R.Tensor((4,), dtype="float32"),
+    ):
+        gv: R.Tuple(
+            R.Tensor((2, 4, 3, 3), dtype="float32"),
+            R.Tensor((4,), dtype="float32"),
+            R.Tensor((4,), dtype="float32"),
+        ) = R.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 4, 3, 3), "float32"))
+    gamma = relax.Var("gamma", R.Tensor((4,), "float32"))
+    beta = relax.Var("beta", R.Tensor((4,), "float32"))
+    moving_mean = relax.Var("moving_mean", R.Tensor((4,), "float32"))
+    moving_var = relax.Var("moving_var", R.Tensor((4,), "float32"))
+
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, gamma, beta, moving_mean, moving_var]):
+        gv = bb.emit(relax.op.nn.batch_norm(x, gamma, beta, moving_mean, moving_var, axis=1))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_layer_norm():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 4, 5), "float32"),
+        gamma: R.Tensor((4, 5), "float32"),
+        beta: R.Tensor((4, 5), "float32"),
+    ) -> R.Tensor((2, 3, 4, 5), "float32"):
+        gv: R.Tensor((2, 3, 4, 5), "float32") = R.nn.layer_norm(x, gamma, 
beta, axes=[-2, -1])
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    gamma = relax.Var("gamma", R.Tensor((4, 5), "float32"))
+    beta = relax.Var("beta", R.Tensor((4, 5), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, gamma, beta]):
+        gv = bb.emit(relax.op.nn.layer_norm(x, gamma, beta, axes=[-2, -1]))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_dropout():
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3), "float32")
+    ) -> R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")):
+        gv: R.Tuple(R.Tensor((2, 3), "float32"), R.Tensor((2, 3), "float32")) 
= R.nn.dropout(
+            x, rate=0.5
+        )
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x]):
+        gv = bb.emit(relax.op.nn.dropout(x))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

