This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new af82d97043 [Relax][Frontend][Onnx] support MaxPool1/2/3D and 
AveragePool1/2/3D (#16681)
af82d97043 is described below

commit af82d970436d25018d216bd02dd70fae9d5e6e83
Author: chengven027-intellif <[email protected]>
AuthorDate: Sun Mar 10 20:47:24 2024 +0800

    [Relax][Frontend][Onnx] support MaxPool1/2/3D and AveragePool1/2/3D (#16681)
    
    support MaxPool1/2/3D and AveragePool1/2/3D
    
    Co-authored-by: cheng wen <chengven027-intellif>
---
 include/tvm/relax/attrs/nn.h                       |  80 +++++
 python/tvm/relax/frontend/onnx/onnx_frontend.py    | 102 ++++--
 python/tvm/relax/op/_op_gradient.py                |   2 +
 python/tvm/relax/op/grad/grad.py                   |  24 +-
 python/tvm/relax/op/nn/__init__.py                 |   4 +
 python/tvm/relax/op/nn/nn.py                       | 354 ++++++++++++++++++++-
 python/tvm/relax/transform/legalize_ops/grad.py    |   2 +
 python/tvm/relax/transform/legalize_ops/nn.py      |  95 ++++++
 python/tvm/topi/nn/pooling.py                      |   7 +-
 src/relax/op/nn/pooling.cc                         | 286 ++++++++++++++++-
 src/relax/op/nn/pooling.h                          |   6 +-
 src/relax/op/tensor/grad.cc                        |   8 +-
 src/relax/op/tensor/grad.h                         |   6 +-
 tests/python/relax/test_frontend_onnx.py           | 251 +++++++++++----
 tests/python/relax/test_op_gradient_numeric.py     |  10 +-
 .../relax/test_transform_legalize_ops_grad.py      |   3 +-
 .../python/relax/test_transform_legalize_ops_nn.py |   7 +-
 17 files changed, 1123 insertions(+), 124 deletions(-)

diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index dd63a70bc4..0bb2dcaab5 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -254,6 +254,43 @@ struct Conv2DTransposeAttrs : public 
tvm::AttrsNode<Conv2DTransposeAttrs> {
   }
 };  // struct Conv2DTransposeAttrs
 
+/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
+struct Pool1DAttrs : public tvm::AttrsNode<Pool1DAttrs> {
+  Array<IntImm> pool_size;
+  Array<IntImm> strides;
+  Array<IntImm> padding;
+  Array<IntImm> dilation;
+  bool ceil_mode;
+  bool count_include_pad;
+  String layout;
+  String out_layout;
+
+  TVM_DECLARE_ATTRS(Pool1DAttrs, "relax.attrs.Pool1DAttrs") {
+    TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the 
convolution.");
+    TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the 
convolution.");
+    TVM_ATTR_FIELD(padding).describe(
+        "If padding is non-zero, then the input is implicitly zero-padded"
+        "Padding support both symmetric and asymmetric as"
+        "one int : same padding used on all sides"
+        "two int : padding width in the order of (left, right)");
+    TVM_ATTR_FIELD(ceil_mode).describe(
+        "A boolean indicating if use ceil or floor to compute the output 
shape. By using ceil, "
+        "every element in the input tensor will be covered by a sliding 
window.");
+    TVM_ATTR_FIELD(count_include_pad)
+        .describe("When true, will include padding to compute the average");
+    TVM_ATTR_FIELD(layout).set_default("NCW").describe(
+        "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
+        "'N', 'C', 'W' stands for batch, channel, and width"
+        "dimensions respectively. Pooling is applied on the 'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .describe(
+            "Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
+            "'N', 'C', 'W' stands for batch, channel, and width"
+            "dimensions respectively. Pooling is applied on the 'W' 
dimensions.");
+  }
+};  // struct Pool1dAttrs
+
 /*! \brief Attributes used in max_pool2d and avg_pool2d operator */
 struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
   Array<IntImm> pool_size;
@@ -261,6 +298,7 @@ struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
   Array<IntImm> padding;
   Array<IntImm> dilation;
   bool ceil_mode;
+  bool count_include_pad;
   String layout;
   String out_layout;
 
@@ -277,6 +315,8 @@ struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
     TVM_ATTR_FIELD(ceil_mode).describe(
         "A boolean indicating if use ceil or floor to compute the output 
shape. By using ceil, "
         "every element in the input tensor will be covered by a sliding 
window.");
+    TVM_ATTR_FIELD(count_include_pad)
+        .describe("When true, will include padding to compute the average");
     TVM_ATTR_FIELD(layout).describe(
         "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
@@ -291,6 +331,46 @@ struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
   }
 };  // struct Pool2dAttrs
 
+/*! \brief Attributes used in max_pool3d and avg_pool3d operator */
+struct Pool3DAttrs : public tvm::AttrsNode<Pool3DAttrs> {
+  Array<IntImm> pool_size;
+  Array<IntImm> strides;
+  Array<IntImm> padding;
+  Array<IntImm> dilation;
+  bool ceil_mode;
+  bool count_include_pad;
+  String layout;
+  String out_layout;
+
+  TVM_DECLARE_ATTRS(Pool3DAttrs, "relax.attrs.Pool3DAttrs") {
+    TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the 
convolution.");
+    TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the 
convolution.");
+    TVM_ATTR_FIELD(padding).describe(
+        "If padding is non-zero, then the input is implicitly zero-padded"
+        "Padding support both symmetric and asymmetric as"
+        "one int : same padding used on all sides"
+        "three int : back, bottom, right will use same padding as front, top, 
left"
+        "six int : padding width in the order of (front, top, left, back, 
bottom, right)");
+    TVM_ATTR_FIELD(ceil_mode).describe(
+        "A boolean indicating if use ceil or floor to compute the output 
shape. By using ceil, "
+        "every element in the input tensor will be covered by a sliding 
window.");
+    TVM_ATTR_FIELD(count_include_pad)
+        .describe("When true, will include padding to compute the average");
+    TVM_ATTR_FIELD(layout).describe(
+        "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+        "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and 
width"
+        "dimensions respectively. Pooling is applied on the 'D', 'H' and"
+        "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .describe(
+            "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
+            "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, 
and width"
+            "dimensions respectively. Pooling is applied on the 'D', 'H' and"
+            "'W' dimensions.");
+  }
+};  // struct Pool3dAttrs
+
 /*! \brief Attributes for 2d adaptive pool operator */
 struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
   Optional<Array<IntImm>> output_size;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 092e73baa1..a047e8701c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1438,21 +1438,40 @@ class BatchNormalization(OnnxOpConverter):
         )
 
 
-class MaxPool(OnnxOpConverter):
-    """Converts an onnx MaxPool node into an equivalent Relax expression."""
+class Pool(OnnxOpConverter):
+    """A helper class for pool op converters."""
+
+    name = ""
 
     @classmethod
-    def _impl_v12(cls, bb, inputs, attr, params):
+    def get_pad_pair(cls, input1d, kernel1d, stride1d, mode):
+        """infer pad size"""
+        if input1d % stride1d == 0:
+            pad = max(kernel1d - stride1d, 0)
+        else:
+            pad = max(kernel1d - (input1d % stride1d), 0)
+        pad_before = pad // 2
+        pad_after = pad - pad_before
+        if "LOWER" in mode:
+            return [pad_after, pad_before]
+        return [pad_before, pad_after]
+
+    @classmethod
+    def _impl_v1(cls, bb, inputs, attr, params):
         # Unpack inputs and attributes.
         data = inputs[0]
+        input_shape = data.struct_info.shape
+        ndim = len(input_shape)
+
         auto_pad = attr.get("auto_pad", b"NOTSET").decode("utf-8")
         ceil_mode = attr.get("ceil_mode", 0)
-        dilations = attr.get("dilations", [1, 1])
+        dilations = attr.get("dilations", [1] * (ndim - 2))
         kernel_shape = attr.get("kernel_shape")
         pads = attr.get("pads", 0)
-        strides = attr.get("strides", [1, 1])
+        strides = attr.get("strides", [1] * (ndim - 2))
+
+        assert len(kernel_shape) in [1, 2, 3], "Currently only 1D/2D/3D 
pooling is supported."
 
-        assert len(kernel_shape) == 2, "Currently only 2D pooling is 
supported."
         assert auto_pad in [
             "NOTSET",
             "SAME_UPPER",
@@ -1461,34 +1480,40 @@ class MaxPool(OnnxOpConverter):
         ], f"Value {auto_pad} in attribute auto_pad is invalid."
 
         if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
-            input_spatial_shape = cls._get_input_spatial_shape(data)
-            output_spatial_shape = [0 for _ in input_spatial_shape]
-
-            pads = _np.array([(0, 0) for _ in range(len(kernel_shape))])
+            pads = []
+            if cls.name == "avg_pool":
+                for axis in range(len(input_shape) - 2):
+                    axis_shape = input_shape[2 + axis]
+                    stride = strides[axis]
+                    kernel = kernel_shape[axis]
+                    pad = cls.get_pad_pair(axis_shape, kernel, stride, 
auto_pad)
+                    pads.append(pad)
+            else:
+                input_spatial_shape = cls._get_input_spatial_shape(data)
+                output_spatial_shape = [0 for _ in input_spatial_shape]
+
+                for i, _ in enumerate(input_spatial_shape):
+                    if auto_pad == "SAME_UPPER":
+                        output_spatial_shape[i] = 
int(_np.ceil(input_spatial_shape[i] / strides[i]))
+                    else:
+                        output_spatial_shape[i] = int(
+                            _np.floor(input_spatial_shape[i] / strides[i])
+                        )
+                    pad_i = (
+                        (output_spatial_shape[i] - 1) * strides[i]
+                        + ((kernel_shape[i] - 1) * dilations[i] + 1)
+                        - input_spatial_shape[i]
+                    )
 
-            for i, _ in enumerate(input_spatial_shape):
-                if auto_pad == "SAME_UPPER":
-                    output_spatial_shape[i] = 
int(_np.ceil(input_spatial_shape[i] / strides[i]))
-                else:
-                    output_spatial_shape[i] = 
int(_np.floor(input_spatial_shape[i] / strides[i]))
-                pad_i = (
-                    (output_spatial_shape[i] - 1) * strides[i]
-                    + ((kernel_shape[i] - 1) * dilations[i] + 1)
-                    - input_spatial_shape[i]
-                )
-                if auto_pad == "SAME_UPPER":
-                    pads[i, 0] = pad_i // 2
-                    pads[i, 1] = pad_i - pads[i, 0]
-                else:
-                    pads[i, 1] = pad_i // 2
-                    pads[i, 0] = pad_i - pads[i, 1]
+                    if auto_pad == "SAME_UPPER":
+                        pads.append([pad_i // 2, pad_i - pad_i // 2])
+                    else:
+                        pads.append([pad_i - pad_i // 2, pad_i // 2])
 
-            # TODO(agladyshev): for now we support only 2D kernel
-            # (top, left, bottom, right)
-            flatten_pads = [pads[0][0], pads[1][0], pads[0][1], pads[1][1]]
-            pads = tuple(flatten_pads)
+            pads = tuple([val for pair in zip(*pads) for val in pair])
 
-        return relax.op.nn.max_pool2d(data, kernel_shape, strides, pads, 
dilations, ceil_mode)
+        op = getattr(relax.op.nn, cls.name + str(len(kernel_shape)) + "d")
+        return op(data, kernel_shape, strides, pads, dilations, ceil_mode)
 
     @classmethod
     def _get_input_spatial_shape(cls, tensor):
@@ -1496,6 +1521,18 @@ class MaxPool(OnnxOpConverter):
         return _np.array([int(d) for d in tensor.struct_info.shape], 
dtype="int64")[2:]
 
 
+class MaxPool(Pool):
+    """Converts an onnx MaxPool node into an equivalent Relax expression."""
+
+    name = "max_pool"
+
+
+class AveragePool(Pool):
+    """Converts an onnx AveragePool node into an equivalent Relax expression."""
+
+    name = "avg_pool"
+
+
 class GlobalAveragePool(OnnxOpConverter):
     """Converts an onnx GlobalAveragePool node into an equivalent Relax 
expression."""
 
@@ -1922,9 +1959,10 @@ def _get_convert_map():
         "Split": Split,
         "Tile": Tile,
         "BatchNormalization": BatchNormalization,
+        "MaxPool": MaxPool,
+        "AveragePool": AveragePool,
         "GlobalAveragePool": GlobalAveragePool,
         "Flatten": Flatten,
-        "MaxPool": MaxPool,
         "Identity": Identity,
         "Resize": Resize,
         "Einsum": Einsum,
diff --git a/python/tvm/relax/op/_op_gradient.py 
b/python/tvm/relax/op/_op_gradient.py
index 1b0ebfd5e4..6878f97331 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -1279,6 +1279,7 @@ out_layout)`
             orig_call.attrs.padding,
             orig_call.attrs.dilation,
             orig_call.attrs.ceil_mode,
+            orig_call.attrs.count_include_pad,
             orig_call.attrs.layout,
             orig_call.attrs.out_layout,
         )
@@ -1310,6 +1311,7 @@ out_layout)`
             orig_call.attrs.padding,
             orig_call.attrs.dilation,
             orig_call.attrs.ceil_mode,
+            orig_call.attrs.count_include_pad,
             orig_call.attrs.layout,
             orig_call.attrs.out_layout,
         )
diff --git a/python/tvm/relax/op/grad/grad.py b/python/tvm/relax/op/grad/grad.py
index 2218db2232..304ad9cc2f 100644
--- a/python/tvm/relax/op/grad/grad.py
+++ b/python/tvm/relax/op/grad/grad.py
@@ -130,6 +130,7 @@ def max_pool2d_backward(
     padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
     dilation: Tuple[int, int] = (1, 1),
     ceil_mode: bool = False,
+    count_include_pad: bool = False,
     layout: str = "NCHW",
     out_layout: Optional[str] = None,
 ) -> Expr:
@@ -147,7 +148,16 @@ def max_pool2d_backward(
       The gradient w.r.t. data.
     """
     return _ffi_api.max_pool2d_backward(  # type: ignore
-        output_grad, data, pool_size, strides, padding, dilation, ceil_mode, 
layout, out_layout
+        output_grad,
+        data,
+        pool_size,
+        strides,
+        padding,
+        dilation,
+        ceil_mode,
+        count_include_pad,
+        layout,
+        out_layout,
     )
 
 
@@ -159,6 +169,7 @@ def avg_pool2d_backward(
     padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
     dilation: Tuple[int, int] = (1, 1),
     ceil_mode: bool = False,
+    count_include_pad: bool = False,
     layout: str = "NCHW",
     out_layout: Optional[str] = None,
 ) -> Expr:
@@ -176,7 +187,16 @@ def avg_pool2d_backward(
       The gradient w.r.t. data.
     """
     return _ffi_api.avg_pool2d_backward(  # type: ignore
-        output_grad, data, pool_size, strides, padding, dilation, ceil_mode, 
layout, out_layout
+        output_grad,
+        data,
+        pool_size,
+        strides,
+        padding,
+        dilation,
+        ceil_mode,
+        count_include_pad,
+        layout,
+        out_layout,
     )
 
 
diff --git a/python/tvm/relax/op/nn/__init__.py 
b/python/tvm/relax/op/nn/__init__.py
index d90b207314..cb90a86883 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -19,7 +19,9 @@ from .nn import (
     adaptive_avg_pool2d,
     attention,
     attention_var_len,
+    avg_pool1d,
     avg_pool2d,
+    avg_pool3d,
     batch_norm,
     conv1d,
     conv1d_transpose,
@@ -34,7 +36,9 @@ from .nn import (
     layer_norm,
     leakyrelu,
     log_softmax,
+    max_pool1d,
     max_pool2d,
+    max_pool3d,
     nll_loss,
     pad,
     relu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 151c43af55..26ba894e84 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -542,6 +542,87 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
     return _ffi_api.pad(data, pad_width, pad_value, pad_mode)
 
 
+def max_pool1d(
+    data: Expr,
+    pool_size: Union[int, Tuple[int, int]] = (1,),
+    strides: Union[int, Tuple[int, int]] = (1,),
+    padding: Union[int, Tuple[int, ...]] = (0, 0),
+    dilation: Union[int, Tuple[int, int]] = (1,),
+    ceil_mode: bool = False,
+    count_include_pad: bool = False,
+    layout: str = "NCW",
+    out_layout: Optional[str] = None,
+) -> Expr:
+    r"""1D maximum pooling operator.
+
+    This operator takes data as input and does 1D max value calculation
+    with in pool_size sized window by striding defined by stride.
+
+    In the default case, where the data_layout is `NCW`
+    a data Tensor with shape `(batch_size, channels, width)`,
+    to produce an output Tensor.
+
+    The ceil_mode is used to take ceil or floor while computing out shape.
+    count_include_pad indicates including or excluding padded input values in 
computation.
+    This operator accepts data layout specification.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    pool_size : Union[int, Tuple[int]]
+        The size of window for pooling. It is required to have length 1.
+
+    strides : Union[int, Tuple[int]]
+        The strides of pooling. It is required to have length 1.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding for pooling. It is required to have length either 1 or 2.
+
+    dilation : Union[int, Tuple[int]]
+        The dilation of pooling. It is required to have length 1.
+
+    ceil_mode : bool
+        A boolean indicating if use ceil or floor to compute the output shape.
+        By using ceil, every element in the input tensor will be covered by a 
sliding window.
+
+    count_include_pad : bool, optional
+        To include padding to compute the average.
+
+    layout : str
+        Layout of the input.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    Returns
+    -------
+    result : Expr
+        The computed result.
+    """
+    if isinstance(pool_size, int):
+        pool_size = (pool_size,)
+    if isinstance(strides, int):
+        strides = (strides,)
+    if isinstance(dilation, int):
+        dilation = (dilation,)
+    if isinstance(padding, int):
+        padding = (padding, padding)
+
+    return _ffi_api.max_pool1d(  # type: ignore
+        data,
+        pool_size,
+        strides,
+        padding,
+        dilation,
+        ceil_mode,
+        count_include_pad,
+        layout,
+        out_layout,
+    )
+
+
 def max_pool2d(
     data: Expr,
     pool_size: Union[int, Tuple[int, int]] = (1, 1),
@@ -549,6 +630,7 @@ def max_pool2d(
     padding: Union[int, Tuple[int, ...]] = (0, 0),
     dilation: Union[int, Tuple[int, int]] = (1, 1),
     ceil_mode: bool = False,
+    count_include_pad: bool = False,
     layout: str = "NCHW",
     out_layout: Optional[str] = None,
 ) -> Expr:
@@ -593,6 +675,9 @@ def max_pool2d(
         A boolean indicating if use ceil or floor to compute the output shape.
         By using ceil, every element in the input tensor will be covered by a 
sliding window.
 
+    count_include_pad : bool, optional
+        To include padding to compute the average.
+
     layout : str
         Layout of the input.
 
@@ -614,7 +699,177 @@ def max_pool2d(
         padding = (padding, padding, padding, padding)
 
     return _ffi_api.max_pool2d(  # type: ignore
-        data, pool_size, strides, padding, dilation, ceil_mode, layout, 
out_layout
+        data,
+        pool_size,
+        strides,
+        padding,
+        dilation,
+        ceil_mode,
+        count_include_pad,
+        layout,
+        out_layout,
+    )
+
+
+def max_pool3d(
+    data: Expr,
+    pool_size: Union[int, Tuple[int, int]] = (1, 1, 1),
+    strides: Union[int, Tuple[int, int]] = (1, 1, 1),
+    padding: Union[int, Tuple[int, ...]] = (0, 0, 0),
+    dilation: Union[int, Tuple[int, int]] = (1, 1, 1),
+    ceil_mode: bool = False,
+    count_include_pad: bool = False,
+    layout: str = "NCDHW",
+    out_layout: Optional[str] = None,
+) -> Expr:
+    r"""3D maximum pooling operator.
+
+    This operator takes data as input and does 3D max value calculation
+    with in pool_size sized window by striding defined by stride.
+
+
+    In the default case, where the data_layout is `NCDHW`
+    a data Tensor with shape `(batch_size, channels, depth, height, width)`,
+    to produce an output Tensor.
+
+    The ceil_mode is used to take ceil or floor while computing out shape.
+    count_include_pad indicates including or excluding padded input values in 
computation.
+    This operator accepts data layout specification.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    pool_size : Union[int, Tuple[int, int, int]]
+        The size of window for pooling. It is required to have length either 1 or 3.
+
+    strides : Union[int, Tuple[int, int, int]]
+        The strides of pooling. It is required to have length either 1 or 3.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding for pooling. It is required to have length either 1, 3 or 6.
+
+    dilation : Union[int, Tuple[int, int, int]]
+        The dilation of pooling. It is required to have length either 1 or 3.
+
+    ceil_mode : bool
+        A boolean indicating if use ceil or floor to compute the output shape.
+        By using ceil, every element in the input tensor will be covered by a 
sliding window.
+
+    count_include_pad : bool, optional
+        To include padding to compute the average.
+
+    layout : str
+        Layout of the input.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    Returns
+    -------
+    result : Expr
+        The computed result.
+    """
+    if isinstance(pool_size, int):
+        pool_size = (pool_size, pool_size, pool_size)
+    if isinstance(strides, int):
+        strides = (strides, strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding, padding, padding, padding, padding)
+
+    return _ffi_api.max_pool3d(  # type: ignore
+        data,
+        pool_size,
+        strides,
+        padding,
+        dilation,
+        ceil_mode,
+        count_include_pad,
+        layout,
+        out_layout,
+    )
+
+
+def avg_pool1d(
+    data: Expr,
+    pool_size: Union[int, Tuple[int, int]] = (1,),
+    strides: Union[int, Tuple[int, int]] = (1,),
+    padding: Union[int, Tuple[int, ...]] = (0, 0),
+    dilation: Union[int, Tuple[int, int]] = (1,),
+    ceil_mode: bool = False,
+    count_include_pad: bool = False,
+    layout: str = "NCW",
+    out_layout: Optional[str] = None,
+) -> Expr:
+    r"""1D average pooling operator.
+
+    This operator takes data as input and does 1D average value calculation
+    with in pool_size sized window by striding defined by stride
+
+    In the default case, where the data_layout is `NCW`
+    a data Tensor with shape `(batch_size, channels, width)`,
+    to produce an output Tensor.
+
+    The ceil_mode is used to take ceil or floor while computing out shape.
+    count_include_pad indicates including or excluding padded input values in 
computation.
+    This operator accepts data layout specification.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    pool_size : Union[int, Tuple[int]]
+        The size of window for pooling. It is required to have length 1.
+
+    strides : Union[int, Tuple[int]]
+        The strides of pooling. It is required to have length 1.
+
+    padding : Union[int, Tuple[int, int]]
+        The padding for pooling. It is required to have length either 1 or 2.
+
+    dilation : Union[int, Tuple[int]]
+        The dilation of pooling. It is required to have length 1.
+
+    ceil_mode : bool
+        A boolean indicating if use ceil or floor to compute the output shape.
+        By using ceil, every element in the input tensor will be covered by a 
sliding window.
+
+    count_include_pad : bool, optional
+        To include padding to compute the average.
+
+    layout : str
+        Layout of the input.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    Returns
+    -------
+    result : Expr
+        The computed result.
+    """
+    if isinstance(pool_size, int):
+        pool_size = (pool_size,)
+    if isinstance(strides, int):
+        strides = (strides,)
+    if isinstance(dilation, int):
+        dilation = (dilation,)
+    if isinstance(padding, int):
+        padding = (padding, padding)
+    return _ffi_api.avg_pool1d(  # type: ignore
+        data,
+        pool_size,
+        strides,
+        padding,
+        dilation,
+        ceil_mode,
+        count_include_pad,
+        layout,
+        out_layout,
     )
 
 
@@ -625,6 +880,7 @@ def avg_pool2d(
     padding: Union[int, Tuple[int, ...]] = (0, 0),
     dilation: Union[int, Tuple[int, int]] = (1, 1),
     ceil_mode: bool = False,
+    count_include_pad: bool = False,
     layout: str = "NCHW",
     out_layout: Optional[str] = None,
 ) -> Expr:
@@ -670,6 +926,9 @@ def avg_pool2d(
         A boolean indicating if use ceil or floor to compute the output shape.
         By using ceil, every element in the input tensor will be covered by a 
sliding window.
 
+    count_include_pad : bool, optional
+        To include padding to compute the average.
+
     layout : str
         Layout of the input.
 
@@ -689,9 +948,98 @@ def avg_pool2d(
         dilation = (dilation, dilation)
     if isinstance(padding, int):
         padding = (padding, padding, padding, padding)
-
     return _ffi_api.avg_pool2d(  # type: ignore
-        data, pool_size, strides, padding, dilation, ceil_mode, layout, 
out_layout
+        data,
+        pool_size,
+        strides,
+        padding,
+        dilation,
+        ceil_mode,
+        count_include_pad,
+        layout,
+        out_layout,
+    )
+
+
+def avg_pool3d(
+    data: Expr,
+    pool_size: Union[int, Tuple[int, int]] = (1, 1, 1),
+    strides: Union[int, Tuple[int, int]] = (1, 1, 1),
+    padding: Union[int, Tuple[int, ...]] = (0, 0, 0),
+    dilation: Union[int, Tuple[int, int]] = (1, 1, 1),
+    ceil_mode: bool = False,
+    count_include_pad: bool = False,
+    layout: str = "NCDHW",
+    out_layout: Optional[str] = None,
+) -> Expr:
+    r"""3D average pooling operator.
+
+    This operator takes data as input and does 3D average value calculation
+    with in pool_size sized window by striding defined by stride
+
+
+    In the default case, where the data_layout is `NCDHW`
+    a data Tensor with shape `(batch_size, channels, depth, height, width)`,
+    to produce an output Tensor.
+
+    The ceil_mode is used to take ceil or floor while computing out shape.
+    count_include_pad indicates including or excluding padded input values in 
computation.
+    This operator accepts data layout specification.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    pool_size : Union[int, Tuple[int, int, int]]
+        The size of window for pooling. It is required to have length either 1 
or 3.
+
+    strides : Union[int, Tuple[int, int, int]]
+        The strides of pooling. It is required to have length either 1 or 3.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding for pooling. It is required to have length either 1, 3 or 
6.
+
+    dilation : Union[int, Tuple[int, int, int]]
+        The dilation of pooling. It is required to have length either 1 or 3.
+
+    ceil_mode : bool
+        A boolean indicating if use ceil or floor to compute the output shape.
+        By using ceil, every element in the input tensor will be covered by a 
sliding window.
+
+    count_include_pad : bool, optional
+        To include padding to compute the average.
+
+    layout : str
+        Layout of the input.
+
+    out_layout : Optional[str]
+        Layout of the output. If not specified, it is the same as data_layout
+
+    Returns
+    -------
+    result : Expr
+        The computed result.
+    """
+    if isinstance(pool_size, int):
+        pool_size = (pool_size, pool_size, pool_size)
+    if isinstance(strides, int):
+        strides = (strides, strides, strides)
+    if isinstance(dilation, int):
+        dilation = (dilation, dilation, dilation)
+    if isinstance(padding, int):
+        padding = (padding, padding, padding, padding, padding, padding)
+
+    return _ffi_api.avg_pool3d(  # type: ignore
+        data,
+        pool_size,
+        strides,
+        padding,
+        dilation,
+        ceil_mode,
+        count_include_pad,
+        layout,
+        out_layout,
     )
 
 
diff --git a/python/tvm/relax/transform/legalize_ops/grad.py 
b/python/tvm/relax/transform/legalize_ops/grad.py
index 1d527bea6a..4fde2a25c3 100644
--- a/python/tvm/relax/transform/legalize_ops/grad.py
+++ b/python/tvm/relax/transform/legalize_ops/grad.py
@@ -125,6 +125,7 @@ def _grad_max_pool2d_backward(bb: BlockBuilder, call: Call) 
-> Expr:
         padding=call.attrs.padding,
         pool_type="max",
         ceil_mode=call.attrs.ceil_mode,
+        count_include_pad=call.attrs.count_include_pad,
         layout=call.attrs.layout,
         primfunc_name_hint="max_pool2d_backward",
     )
@@ -144,6 +145,7 @@ def _grad_avg_pool2d_backward(bb: BlockBuilder, call: Call) 
-> Expr:
         padding=call.attrs.padding,
         pool_type="avg",
         ceil_mode=call.attrs.ceil_mode,
+        count_include_pad=call.attrs.count_include_pad,
         layout=call.attrs.layout,
         primfunc_name_hint="avg_pool2d_backward",
     )
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index f80d28099c..8f5407ff09 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -241,6 +241,29 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.max_pool1d")
+def _nn_max_pool1d(bb: BlockBuilder, call: Call) -> Expr:
+    if call.attrs.out_layout != call.attrs.layout:
+        logging.info(
+            "TOPI max_pool1d does not support different input-output "
+            "layouts, and thus cannot be legalized by TOPI"
+        )
+        return call
+
+    return bb.call_te(
+        topi.nn.pool1d,
+        call.args[0],
+        kernel=call.attrs.pool_size,
+        stride=call.attrs.strides,
+        dilation=call.attrs.dilation,
+        padding=call.attrs.padding,
+        pool_type="max",
+        ceil_mode=call.attrs.ceil_mode,
+        layout=call.attrs.layout,
+        primfunc_name_hint="max_pool1d",
+    )
+
+
 @register_legalize("relax.nn.max_pool2d")
 def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
     if call.attrs.out_layout != call.attrs.layout:
@@ -264,6 +287,53 @@ def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
     )
 
 
+@register_legalize("relax.nn.max_pool3d")
+def _nn_max_pool3d(bb: BlockBuilder, call: Call) -> Expr:
+    if call.attrs.out_layout != call.attrs.layout:
+        logging.info(
+            "TOPI max_pool3d does not support different input-output "
+            "layouts, and thus cannot be legalized by TOPI"
+        )
+        return call
+
+    return bb.call_te(
+        topi.nn.pool3d,
+        call.args[0],
+        kernel=call.attrs.pool_size,
+        stride=call.attrs.strides,
+        dilation=call.attrs.dilation,
+        padding=call.attrs.padding,
+        pool_type="max",
+        ceil_mode=call.attrs.ceil_mode,
+        layout=call.attrs.layout,
+        primfunc_name_hint="max_pool3d",
+    )
+
+
+@register_legalize("relax.nn.avg_pool1d")
+def _nn_avg_pool1d(bb: BlockBuilder, call: Call) -> Expr:
+    if call.attrs.out_layout != call.attrs.layout:
+        logging.info(
+            "TOPI avg_pool1d does not support different input-output "
+            "layouts, and thus cannot be legalized by TOPI"
+        )
+        return call
+
+    return bb.call_te(
+        topi.nn.pool1d,
+        call.args[0],
+        kernel=call.attrs.pool_size,
+        stride=call.attrs.strides,
+        dilation=call.attrs.dilation,
+        padding=call.attrs.padding,
+        pool_type="avg",
+        ceil_mode=call.attrs.ceil_mode,
+        layout=call.attrs.layout,
+        count_include_pad=call.attrs.count_include_pad,
+        primfunc_name_hint="avg_pool1d",
+    )
+
+
 @register_legalize("relax.nn.avg_pool2d")
 def _nn_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
     if call.attrs.out_layout != call.attrs.layout:
@@ -283,10 +353,35 @@ def _nn_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
         pool_type="avg",
         ceil_mode=call.attrs.ceil_mode,
         layout=call.attrs.layout,
+        count_include_pad=call.attrs.count_include_pad,
         primfunc_name_hint="avg_pool2d",
     )
 
 
+@register_legalize("relax.nn.avg_pool3d")
+def _nn_avg_pool3d(bb: BlockBuilder, call: Call) -> Expr:
+    if call.attrs.out_layout != call.attrs.layout:
+        logging.info(
+            "TOPI avg_pool3d does not support different input-output "
+            "layouts, and thus cannot be legalized by TOPI"
+        )
+        return call
+
+    return bb.call_te(
+        topi.nn.pool3d,
+        call.args[0],
+        kernel=call.attrs.pool_size,
+        stride=call.attrs.strides,
+        dilation=call.attrs.dilation,
+        padding=call.attrs.padding,
+        pool_type="avg",
+        ceil_mode=call.attrs.ceil_mode,
+        layout=call.attrs.layout,
+        count_include_pad=call.attrs.count_include_pad,
+        primfunc_name_hint="avg_pool3d",
+    )
+
+
 @register_legalize("relax.nn.adaptive_avg_pool2d")
 def _nn_adaptive_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
     if call.attrs.out_layout != call.attrs.layout:
diff --git a/python/tvm/topi/nn/pooling.py b/python/tvm/topi/nn/pooling.py
index b12c492ed8..a45480f12e 100644
--- a/python/tvm/topi/nn/pooling.py
+++ b/python/tvm/topi/nn/pooling.py
@@ -65,8 +65,8 @@ def pool_grad(
     padding,
     pool_type,
     ceil_mode=False,
-    layout="NCHW",
     count_include_pad=True,
+    layout="NCHW",
 ):
     """Gradient of pooling on height and width dimension of data.
        It decides the height and width dimension according to the layout 
string,
@@ -99,6 +99,9 @@ def pool_grad(
     ceil_mode : bool
         Whether to use ceil when calculating output size.
 
+    count_include_pad: bool
+        Whether include padding in the calculation when pool_type is 'avg'
+
     layout: string
         Layout of the input data.
         The layout is supposed to be composed of upper cases, lower cases and 
numbers,
@@ -108,8 +111,6 @@ def pool_grad(
         [batch_size, channel, height, width, channel_block],
         in which channel_block=16 is a split of dimension channel.
 
-    count_include_pad: bool
-        Whether include padding in the calculation when pool_type is 'avg'
 
     Returns
     -------
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 6c81f5310a..865d419bca 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -25,12 +25,116 @@
 namespace tvm {
 namespace relax {
 
-/* relax.nn.max_pool2d and relax.nn.avg_pool2d */
+/* relax.nn.max_pool1d */
+TVM_REGISTER_NODE_TYPE(Pool1DAttrs);
+
+Expr MakePool1d(String op_name, Expr data, Array<IntImm> pool_size, 
Array<IntImm> strides,
+                Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode,
+                bool count_include_pad, String layout, Optional<String> 
out_layout) {
+  padding = GetCompletePadding1D(std::move(padding));
+
+  CHECK_EQ(pool_size.size(), 1)
+      << "The input pool_size length is expected to be 1. However, the given 
pool_size is "
+      << pool_size;
+  CHECK_EQ(strides.size(), 1)
+      << "The input strides length is expected to be 1. However, the given 
strides is " << strides;
+  CHECK_EQ(dilation.size(), 1)
+      << "The input dilation length is expected to be 1. However, the given 
dilation is "
+      << dilation;
+
+  auto attrs = make_object<Pool1DAttrs>();
+  attrs->pool_size = ConvertIntImmToInt64(pool_size);
+  attrs->strides = ConvertIntImmToInt64(strides);
+  attrs->padding = ConvertIntImmToInt64(padding);
+  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->ceil_mode = ceil_mode;
+  attrs->count_include_pad = count_include_pad;
+  attrs->layout = layout;
+  attrs->out_layout = out_layout.value_or(layout);
+  const Op& op = Op::Get(op_name);
+  return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+Expr max_pool1d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, bool 
count_include_pad, String layout,
+                Optional<String> out_layout) {
+  return MakePool1d("relax.nn.max_pool1d", data, pool_size, strides, padding, 
dilation, ceil_mode,
+                    count_include_pad, layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.max_pool1d").set_body_typed(max_pool1d);
+
+StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<Pool1DAttrs>();
+  auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->layout,
+                                                   /*tgt_layout=*/"NCW",
+                                                   /*tensor_name=*/"data");
+  auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout,
+                                                 /*tgt_layout=*/"NCW",
+                                                 /*tensor_name=*/"output");
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  if (!data_shape.defined()) {
+    return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), 
data_sinfo->vdevice);
+  }
+
+  Array<PrimExpr> data_NCW_shape = 
data2NCW.ForwardShape(data_shape.value()->values);
+
+  PrimExpr input_w = data_NCW_shape[2];
+  PrimExpr kernel_w = attrs->pool_size[0];
+  PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  std::vector<PrimExpr> out_NCW_shape;
+  out_NCW_shape.resize(3);
+  out_NCW_shape[0] = data_NCW_shape[0];
+  out_NCW_shape[1] = data_NCW_shape[1];
+
+  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w 
- 1) - 1;
+  if (attrs->ceil_mode) {
+    numerator_w += attrs->strides[0] - 1;
+  }
+  out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, 
attrs->strides[0]) + 1);
+
+  Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, 
data_sinfo->vdevice);
+}
+
+InferLayoutOutput InferLayoutPool1d(const Call& call,
+                                    const Map<String, Array<String>>& 
desired_layouts,
+                                    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK_EQ(tensor_sinfo->ndim, 3) << "Unsupported initial layout";
+  const auto* attrs = call->attrs.as<Pool1DAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<Pool1DAttrs> new_attrs = make_object<Pool1DAttrs>(*attrs);
+  new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(3), 
layout->layout).name();
+  new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(3), 
layout->layout).name();
+  return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.nn.max_pool1d")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attrs_type<Pool1DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool1D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool1d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.nn.max_pool2d */
 TVM_REGISTER_NODE_TYPE(Pool2DAttrs);
 
 Expr MakePool2d(String op_name, Expr data, Array<IntImm> pool_size, 
Array<IntImm> strides,
-                Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode, 
String layout,
-                Optional<String> out_layout) {
+                Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode,
+                bool count_include_pad, String layout, Optional<String> 
out_layout) {
   padding = GetCompletePadding2D(std::move(padding));
   if (pool_size.size() == 1) {
     pool_size.push_back(pool_size[0]);
@@ -57,6 +161,7 @@ Expr MakePool2d(String op_name, Expr data, Array<IntImm> 
pool_size, Array<IntImm
   attrs->padding = ConvertIntImmToInt64(padding);
   attrs->dilation = ConvertIntImmToInt64(dilation);
   attrs->ceil_mode = ceil_mode;
+  attrs->count_include_pad = count_include_pad;
   attrs->layout = layout;
   attrs->out_layout = out_layout.value_or(layout);
   const Op& op = Op::Get(op_name);
@@ -64,10 +169,10 @@ Expr MakePool2d(String op_name, Expr data, Array<IntImm> 
pool_size, Array<IntImm
 }
 
 Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
-                Array<IntImm> dilation, bool ceil_mode, String layout,
+                Array<IntImm> dilation, bool ceil_mode, bool 
count_include_pad, String layout,
                 Optional<String> out_layout) {
   return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding, 
dilation, ceil_mode,
-                    layout, out_layout);
+                    count_include_pad, layout, out_layout);
 }
 
 TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d);
@@ -143,11 +248,159 @@ TVM_REGISTER_OP("relax.nn.max_pool2d")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.max_pool3d */
+TVM_REGISTER_NODE_TYPE(Pool3DAttrs);
+
+Expr MakePool3d(String op_name, Expr data, Array<IntImm> pool_size, 
Array<IntImm> strides,
+                Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode,
+                bool count_include_pad, String layout, Optional<String> 
out_layout) {
+  padding = GetCompletePadding3D(std::move(padding));
+  if (pool_size.size() == 1) {
+    pool_size.push_back(pool_size[0]);
+    pool_size.push_back(pool_size[0]);
+  }
+  if (strides.size() == 1) {
+    strides.push_back(strides[0]);
+    strides.push_back(strides[0]);
+  }
+  if (dilation.size() == 1) {
+    dilation.push_back(dilation[0]);
+    dilation.push_back(dilation[0]);
+  }
+
+  CHECK_EQ(pool_size.size(), 3)
+      << "The input pool_size length is expected to be 3. However, the given 
pool_size is "
+      << pool_size;
+  CHECK_EQ(strides.size(), 3)
+      << "The input strides length is expected to be 3. However, the given 
strides is " << strides;
+  CHECK_EQ(dilation.size(), 3)
+      << "The input dilation length is expected to be 3. However, the given 
dilation is "
+      << dilation;
+
+  auto attrs = make_object<Pool3DAttrs>();
+  attrs->pool_size = ConvertIntImmToInt64(pool_size);
+  attrs->strides = ConvertIntImmToInt64(strides);
+  attrs->padding = ConvertIntImmToInt64(padding);
+  attrs->dilation = ConvertIntImmToInt64(dilation);
+  attrs->ceil_mode = ceil_mode;
+  attrs->count_include_pad = count_include_pad;
+  attrs->layout = layout;
+  attrs->out_layout = out_layout.value_or(layout);
+  const Op& op = Op::Get(op_name);
+  return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+Expr max_pool3d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, bool 
count_include_pad, String layout,
+                Optional<String> out_layout) {
+  return MakePool3d("relax.nn.max_pool3d", data, pool_size, strides, padding, 
dilation, ceil_mode,
+                    count_include_pad, layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.max_pool3d").set_body_typed(max_pool3d);
+
+StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<Pool3DAttrs>();
+  auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, attrs->layout,
+                                                     /*tgt_layout=*/"NCDHW",
+                                                     /*tensor_name=*/"data");
+  auto [out_layout, out2NCDHW] = CheckTensorLayout(call, ctx, 
attrs->out_layout,
+                                                   /*tgt_layout=*/"NCDHW",
+                                                   /*tensor_name=*/"output");
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  if (!data_shape.defined()) {
+    return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), 
data_sinfo->vdevice);
+  }
+
+  Array<PrimExpr> data_NCDHW_shape = 
data2NCDHW.ForwardShape(data_shape.value()->values);
+
+  PrimExpr input_d = data_NCDHW_shape[2];
+  PrimExpr input_h = data_NCDHW_shape[3];
+  PrimExpr input_w = data_NCDHW_shape[4];
+  PrimExpr kernel_d = attrs->pool_size[0];
+  PrimExpr kernel_h = attrs->pool_size[1];
+  PrimExpr kernel_w = attrs->pool_size[2];
+  PrimExpr padding_d = attrs->padding[0] + attrs->padding[3];
+  PrimExpr padding_h = attrs->padding[1] + attrs->padding[4];
+  PrimExpr padding_w = attrs->padding[2] + attrs->padding[5];
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  std::vector<PrimExpr> out_NCDHW_shape;
+  out_NCDHW_shape.resize(5);
+  out_NCDHW_shape[0] = data_NCDHW_shape[0];
+  out_NCDHW_shape[1] = data_NCDHW_shape[1];
+
+  PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d 
- 1) - 1;
+  PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h 
- 1) - 1;
+  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w 
- 1) - 1;
+  if (attrs->ceil_mode) {
+    numerator_d += attrs->strides[0] - 1;
+    numerator_h += attrs->strides[1] - 1;
+    numerator_w += attrs->strides[2] - 1;
+  }
+  out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d, 
attrs->strides[0]) + 1);
+  out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h, 
attrs->strides[1]) + 1);
+  out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w, 
attrs->strides[2]) + 1);
+
+  Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, 
data_sinfo->vdevice);
+}
+
+InferLayoutOutput InferLayoutPool3d(const Call& call,
+                                    const Map<String, Array<String>>& 
desired_layouts,
+                                    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK_EQ(tensor_sinfo->ndim, 5) << "Unsupported initial layout";
+  const auto* attrs = call->attrs.as<Pool3DAttrs>();
+  ICHECK(attrs) << "Invalid Call";
+
+  LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  ObjectPtr<Pool3DAttrs> new_attrs = make_object<Pool3DAttrs>(*attrs);
+  new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(5), 
layout->layout).name();
+  new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(5), 
layout->layout).name();
+  return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.nn.max_pool3d")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attrs_type<Pool3DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool3D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool3d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.nn.avg_pool1d */
+Expr avg_pool1d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, bool 
count_include_pad, String layout,
+                Optional<String> out_layout) {
+  return MakePool1d("relax.nn.avg_pool1d", data, pool_size, strides, padding, 
dilation, ceil_mode,
+                    count_include_pad, layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool1d").set_body_typed(avg_pool1d);
+
+TVM_REGISTER_OP("relax.nn.avg_pool1d")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attrs_type<Pool1DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool1D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool1d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.nn.avg_pool2d */
 Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
-                Array<IntImm> dilation, bool ceil_mode, String layout,
+                Array<IntImm> dilation, bool ceil_mode, bool 
count_include_pad, String layout,
                 Optional<String> out_layout) {
   return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding, 
dilation, ceil_mode,
-                    layout, out_layout);
+                    count_include_pad, layout, out_layout);
 }
 
 TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d);
@@ -161,6 +414,25 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d")
     .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.avg_pool3d */
+Expr avg_pool3d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
+                Array<IntImm> dilation, bool ceil_mode, bool 
count_include_pad, String layout,
+                Optional<String> out_layout) {
+  return MakePool3d("relax.nn.avg_pool3d", data, pool_size, strides, padding, 
dilation, ceil_mode,
+                    count_include_pad, layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool3d").set_body_typed(avg_pool3d);
+
+TVM_REGISTER_OP("relax.nn.avg_pool3d")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor")
+    .set_attrs_type<Pool3DAttrs>()
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool3D)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool3d)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", 
MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 /* relax.nn.adaptive_avg_pool2d */
 TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
 
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
index 63d2e76772..7fd66f2b44 100644
--- a/src/relax/op/nn/pooling.h
+++ b/src/relax/op/nn/pooling.h
@@ -34,11 +34,13 @@ namespace relax {
 
 /*! \brief 2D maximum pooling operator. */
 Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
-                Array<IntImm> dilation, bool ceil_mode, String layout, 
Optional<String> out_layout);
+                Array<IntImm> dilation, bool ceil_mode, bool 
count_include_pad, String layout,
+                Optional<String> out_layout);
 
 /*! \brief 2D average pooling operator. */
 Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides, 
Array<IntImm> padding,
-                Array<IntImm> dilation, bool ceil_mode, String layout, 
Optional<String> out_layout);
+                Array<IntImm> dilation, bool ceil_mode, bool 
count_include_pad, String layout,
+                Optional<String> out_layout);
 
 /*! \brief 2D adaptive average pooling operator. */
 Expr adaptive_avg_pool2d(Expr data, Optional<Array<IntImm>> output_size, 
String layout,
diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc
index 6f30684460..70114093e3 100644
--- a/src/relax/op/tensor/grad.cc
+++ b/src/relax/op/tensor/grad.cc
@@ -130,13 +130,15 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward")
 /* relax.grad.max_pool2d_backward */
 Expr max_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
                          Array<IntImm> strides, Array<IntImm> padding, 
Array<IntImm> dilation,
-                         bool ceil_mode, String layout, Optional<String> 
out_layout) {
+                         bool ceil_mode, bool count_include_pad, String layout,
+                         Optional<String> out_layout) {
   auto attrs = make_object<Pool2DAttrs>();
   attrs->pool_size = std::move(pool_size);
   attrs->strides = ConvertIntImmToInt64(strides);
   attrs->padding = ConvertIntImmToInt64(padding);
   attrs->dilation = ConvertIntImmToInt64(dilation);
   attrs->ceil_mode = ceil_mode;
+  attrs->count_include_pad = count_include_pad;
   attrs->layout = layout;
   attrs->out_layout = out_layout.value_or(layout);
   static const Op& op = Op::Get("relax.grad.max_pool2d_backward");
@@ -160,13 +162,15 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward")
 /* relax.grad.avg_pool2d_backward */
 Expr avg_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
                          Array<IntImm> strides, Array<IntImm> padding, 
Array<IntImm> dilation,
-                         bool ceil_mode, String layout, Optional<String> 
out_layout) {
+                         bool ceil_mode, bool count_include_pad, String layout,
+                         Optional<String> out_layout) {
   auto attrs = make_object<Pool2DAttrs>();
   attrs->pool_size = std::move(pool_size);
   attrs->strides = ConvertIntImmToInt64(strides);
   attrs->padding = ConvertIntImmToInt64(padding);
   attrs->dilation = ConvertIntImmToInt64(dilation);
   attrs->ceil_mode = ceil_mode;
+  attrs->count_include_pad = count_include_pad;
   attrs->layout = layout;
   attrs->out_layout = out_layout.value_or(layout);
   static const Op& op = Op::Get("relax.grad.avg_pool2d_backward");
diff --git a/src/relax/op/tensor/grad.h b/src/relax/op/tensor/grad.h
index 886516020d..228de315af 100644
--- a/src/relax/op/tensor/grad.h
+++ b/src/relax/op/tensor/grad.h
@@ -48,13 +48,15 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions, 
Expr targets, Optiona
  * relax.max_pool2d. Returns the gradient w.r.t. data. */
 Expr max_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
                          Array<IntImm> strides, Array<IntImm> padding, 
Array<IntImm> dilation,
-                         bool ceil_mode, String layout, Optional<String> 
out_layout);
+                         bool ceil_mode, bool count_include_pad, String layout,
+                         Optional<String> out_layout);
 
 /*! \brief Backward operator of relax.avg_pool2d. All parameters except 
output_grad is the same as
  * relax.avg_pool2d. Returns the gradient w.r.t. data. */
 Expr avg_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
                          Array<IntImm> strides, Array<IntImm> padding, 
Array<IntImm> dilation,
-                         bool ceil_mode, String layout, Optional<String> 
out_layout);
+                         bool ceil_mode, bool count_include_pad, String layout,
+                         Optional<String> out_layout);
 
 /*! \brief Backward operator of relax.take. All parameters except output_grad 
is the same as
  * relax.take. Returns the gradient w.r.t. data. */
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index 473766b749..32778cdd55 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1587,69 +1587,194 @@ def test_batch_norm():
     check_correctness(model, opset=15)
 
 
-def test_max_pool():
-    # Pool2D
-    verify_unary(
-        "MaxPool",
-        [1, 1, 32, 32],
-        dict(
-            auto_pad="NOTSET",
-            kernel_shape=[3, 3],
-            pads=[1, 1, 1, 1],
-            strides=[1, 1],
-        ),
-    )
-    # Pool2D with stride
-    verify_unary(
-        "MaxPool",
-        [1, 1, 32, 32],
-        dict(
-            auto_pad="NOTSET",
-            kernel_shape=[3, 3],
-            pads=[1, 1, 1, 1],
-            strides=[2, 2],
-        ),
-    )
-    # Pool2D with stride and autopadding
-    verify_unary(
-        "MaxPool",
-        [1, 1, 32, 32],
-        dict(
-            auto_pad="SAME_UPPER",
-            kernel_shape=[3, 7],
-            pads=None,
-            strides=[3, 2],
-        ),
-    )
-    verify_unary(
-        "MaxPool",
-        [1, 1, 32, 32],
-        dict(
-            auto_pad="SAME_LOWER",
-            kernel_shape=[3, 3],
-            pads=None,
-            strides=[2, 2],
-        ),
-    )
-    verify_unary(
-        "MaxPool",
-        [1, 1, 32, 32],
-        dict(
-            auto_pad="VALID",
-            kernel_shape=[3, 3],
-            pads=None,
-            strides=[2, 2],
-        ),
-    )
-    verify_unary(
-        "MaxPool",
-        [1, 1, 32, 32],
-        dict(
-            auto_pad="SAME_UPPER",
-            kernel_shape=[3, 3],
-            pads=None,
-        ),
-    )
+def test_maxpool_and_averagepool():
+    for pool_name in ["MaxPool", "AveragePool"]:
+        # Pool1D
+        verify_unary(
+            pool_name,
+            [1, 1, 32],
+            dict(
+                auto_pad="NOTSET",
+                kernel_shape=[3],
+                pads=[1, 1],
+                strides=[1],
+            ),
+        )
+        # Pool1D with stride
+        verify_unary(
+            pool_name,
+            [1, 1, 32],
+            dict(
+                auto_pad="NOTSET",
+                kernel_shape=[3],
+                pads=[1, 2],
+                strides=[2],
+            ),
+        )
+        # Pool1D with stride and autopadding
+        verify_unary(
+            pool_name,
+            [1, 1, 32],
+            dict(
+                auto_pad="SAME_UPPER",
+                kernel_shape=[7],
+                pads=None,
+                strides=[2],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32],
+            dict(
+                auto_pad="SAME_LOWER",
+                kernel_shape=[4],
+                pads=None,
+                strides=[4],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32],
+            dict(
+                auto_pad="VALID",
+                kernel_shape=[5],
+                pads=None,
+                strides=[5],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32],
+            dict(
+                auto_pad="SAME_UPPER",
+                kernel_shape=[3],
+                pads=None,
+            ),
+        )
+        # Pool2D
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32],
+            dict(
+                auto_pad="NOTSET",
+                kernel_shape=[3, 3],
+                pads=[1, 1, 1, 1],
+                strides=[1, 1],
+            ),
+        )
+        # Pool2D with stride
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32],
+            dict(
+                auto_pad="NOTSET",
+                kernel_shape=[3, 3],
+                pads=[1, 1, 1, 1],
+                strides=[2, 2],
+            ),
+        )
+        # Pool2D with stride and autopadding
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32],
+            dict(
+                auto_pad="SAME_UPPER",
+                kernel_shape=[3, 7],
+                pads=None,
+                strides=[3, 2],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32],
+            dict(
+                auto_pad="SAME_LOWER",
+                kernel_shape=[3, 3],
+                pads=None,
+                strides=[2, 2],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32],
+            dict(
+                auto_pad="VALID",
+                kernel_shape=[3, 3],
+                pads=None,
+                strides=[2, 2],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32],
+            dict(
+                auto_pad="SAME_UPPER",
+                kernel_shape=[3, 3],
+                pads=None,
+            ),
+        )
+        # Pool3D
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32, 32],
+            dict(
+                auto_pad="NOTSET",
+                kernel_shape=[3, 3, 4],
+                pads=[1, 2, 1, 1, 2, 2],
+                strides=[1, 1, 1],
+            ),
+        )
+        # Pool3D with stride
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32, 32],
+            dict(
+                auto_pad="NOTSET",
+                kernel_shape=[3, 4, 3],
+                pads=[1, 1, 1, 1, 1, 2],
+                strides=[2, 2, 3],
+            ),
+        )
+        # Pool3D with stride and autopadding
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32, 32],
+            dict(
+                auto_pad="SAME_UPPER",
+                kernel_shape=[4, 3, 3],
+                pads=None,
+                strides=[3, 2, 2],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32, 32],
+            dict(
+                auto_pad="SAME_LOWER",
+                kernel_shape=[3, 3, 4],
+                pads=None,
+                strides=[2, 2, 2],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32, 32],
+            dict(
+                auto_pad="VALID",
+                kernel_shape=[3, 3, 5],
+                pads=None,
+                strides=[2, 2, 3],
+            ),
+        )
+        verify_unary(
+            pool_name,
+            [1, 1, 32, 32, 32],
+            dict(
+                auto_pad="SAME_UPPER",
+                kernel_shape=[3, 3, 5],
+                pads=None,
+            ),
+        )
 
 
 def test_global_average_pool():
diff --git a/tests/python/relax/test_op_gradient_numeric.py 
b/tests/python/relax/test_op_gradient_numeric.py
index bc5cb0f5be..acf0f615dd 100644
--- a/tests/python/relax/test_op_gradient_numeric.py
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -802,11 +802,17 @@ def test_conv2d(target, dev, c2d_shape1, c2d_shape2, 
c2d_kwargs):
     ),
     (
         (3, 3),
-        {"strides": (2, 2), "padding": (1, 2), "dilation": (1, 1)},
+        {"strides": (2, 2), "padding": (1, 2), "dilation": (1, 1), 
"count_include_pad": True},
     ),
     (
         (5, 5),
-        {"strides": (2, 2), "padding": (2, 1), "dilation": (1, 1), 
"ceil_mode": True},
+        {
+            "strides": (2, 2),
+            "padding": (2, 1),
+            "dilation": (1, 1),
+            "ceil_mode": True,
+            "count_include_pad": True,
+        },
     ),
 )
 
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py 
b/tests/python/relax/test_transform_legalize_ops_grad.py
index 19d1a106f8..f13748d2fa 100644
--- a/tests/python/relax/test_transform_legalize_ops_grad.py
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -282,8 +282,7 @@ def test_avg_pool2d_backward():
                     T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3])
                     with T.init():
                         T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
-                    T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < 
T.int64(3), T.int64(0), T.Div((v_ax2 - T.int64(3)), T.int64(2)) + T.int64(1)) 
<= T.Div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh and T.Div((v_ax2 + 
T.int64(2)), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), 
T.int64(0), T.Div((v_ax3 - T.int64(4)), T.int64(2)) + T.int64(1)) <= 
T.Div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww and T.Div((v_ax [...]
-
+                    T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = 
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < 
T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= 
T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2), 
T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0), 
T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + 
T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64 [...]
         @R.function
         def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: 
R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), 
dtype="float32"):
             cls = Expected
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py 
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 29171daaae..92d139d23b 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -743,7 +743,7 @@ def test_avg_pool2d():
                     T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
-                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) - 
v_ax1 * T.int64(2)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax2 * 
T.int64(2)) + T.int64(2)))
+                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3] / T.Cast("float32", T.max((T.min(v_ax1 * T.int64(2) + 
T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax1 * T.int64(2), 
T.int64(0)) - v_ax1 * T.int64(2)) * (T.min(v_ax2 * T.int64(2) + T.int64(1), 
T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax2 * T.int64(2), T.int64(0)) 
- v_ax2 * T.int64(2)), T.int64(1)))
 
         @R.function
         def main(x: R.Tensor((4, 112, 112, 6), dtype="float32")) -> 
R.Tensor((4, 56, 56, 6), dtype="float32"):
@@ -785,8 +785,7 @@ def test_avg_pool2d_NCHW16c():
                     T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                     T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
                     T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
-                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32", 
(T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1)) * (T.min(T.int64(2), 
T.int64(111) - v_ax3) + T.int64(1)))
-
+                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = 
pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32", 
T.max((T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1) - T.max(T.int64(0) 
- v_ax2, T.int64(0))) * (T.min(T.int64(2), T.int64(111) - v_ax3) + T.int64(1) - 
T.max(T.int64(0) - v_ax3, T.int64(0))), T.int64(1)))
         @R.function
         def main(x: R.Tensor((4, 4, 112, 112, 16), dtype="float32")) -> 
R.Tensor((4, 4, 110, 110, 16), dtype="float32"):
             gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4, 
4, 110, 110, 16), dtype="float32"))
@@ -834,7 +833,7 @@ def test_avg_pool2d_ceil_mode():
                     T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                     T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
-                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) - 
v_ax2 * T.int64(3)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax3 * 
T.int64(3)) + T.int64(2)))
+                    pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0, 
v_ax1, v_ax2, v_ax3] / T.Cast("float32", T.max((T.min(v_ax2 * T.int64(3) + 
T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax2 * T.int64(3), 
T.int64(0)) - v_ax2 * T.int64(3)) * (T.min(v_ax3 * T.int64(3) + T.int64(1), 
T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax3 * T.int64(3), T.int64(0)) 
- v_ax3 * T.int64(3)), T.int64(1)))
 
         @R.function
         def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) -> 
R.Tensor((4, 6, 38, 38), dtype="float32"):

Reply via email to