This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new af82d97043 [Relax][Frontend][Onnx] support MaxPool1/2/3D and
AveragePool1/2/3D (#16681)
af82d97043 is described below
commit af82d970436d25018d216bd02dd70fae9d5e6e83
Author: chengven027-intellif <[email protected]>
AuthorDate: Sun Mar 10 20:47:24 2024 +0800
[Relax][Frontend][Onnx] support MaxPool1/2/3D and AveragePool1/2/3D (#16681)
support MaxPool1/2/3D and AveragePool1/2/3D
Co-authored-by: cheng wen <chengven027-intellif>
---
include/tvm/relax/attrs/nn.h | 80 +++++
python/tvm/relax/frontend/onnx/onnx_frontend.py | 102 ++++--
python/tvm/relax/op/_op_gradient.py | 2 +
python/tvm/relax/op/grad/grad.py | 24 +-
python/tvm/relax/op/nn/__init__.py | 4 +
python/tvm/relax/op/nn/nn.py | 354 ++++++++++++++++++++-
python/tvm/relax/transform/legalize_ops/grad.py | 2 +
python/tvm/relax/transform/legalize_ops/nn.py | 95 ++++++
python/tvm/topi/nn/pooling.py | 7 +-
src/relax/op/nn/pooling.cc | 286 ++++++++++++++++-
src/relax/op/nn/pooling.h | 6 +-
src/relax/op/tensor/grad.cc | 8 +-
src/relax/op/tensor/grad.h | 6 +-
tests/python/relax/test_frontend_onnx.py | 251 +++++++++++----
tests/python/relax/test_op_gradient_numeric.py | 10 +-
.../relax/test_transform_legalize_ops_grad.py | 3 +-
.../python/relax/test_transform_legalize_ops_nn.py | 7 +-
17 files changed, 1123 insertions(+), 124 deletions(-)
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index dd63a70bc4..0bb2dcaab5 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -254,6 +254,43 @@ struct Conv2DTransposeAttrs : public
tvm::AttrsNode<Conv2DTransposeAttrs> {
}
}; // struct Conv2DTransposeAttrs
+/*! \brief Attributes used in max_pool1d and avg_pool1d operator */
+struct Pool1DAttrs : public tvm::AttrsNode<Pool1DAttrs> {
+ Array<IntImm> pool_size;
+ Array<IntImm> strides;
+ Array<IntImm> padding;
+ Array<IntImm> dilation;
+ bool ceil_mode;
+ bool count_include_pad;
+ String layout;
+ String out_layout;
+
+ TVM_DECLARE_ATTRS(Pool1DAttrs, "relax.attrs.Pool1DAttrs") {
+ TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the
pooling.");
+    TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the
pooling.");
+ TVM_ATTR_FIELD(padding).describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "two int : padding width in the order of (left, right)");
+ TVM_ATTR_FIELD(ceil_mode).describe(
+ "A boolean indicating if use ceil or floor to compute the output
shape. By using ceil, "
+ "every element in the input tensor will be covered by a sliding
window.");
+ TVM_ATTR_FIELD(count_include_pad)
+ .describe("When true, will include padding to compute the average");
+ TVM_ATTR_FIELD(layout).set_default("NCW").describe(
+ "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Pooling is applied on the 'W' dimensions.");
+ TVM_ATTR_FIELD(out_layout)
+ .describe(
+ "Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
+ "'N', 'C', 'W' stands for batch, channel, and width"
+ "dimensions respectively. Pooling is applied on the 'W'
dimensions.");
+ }
+}; // struct Pool1DAttrs
+
/*! \brief Attributes used in max_pool2d and avg_pool2d operator */
struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
Array<IntImm> pool_size;
@@ -261,6 +298,7 @@ struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
Array<IntImm> padding;
Array<IntImm> dilation;
bool ceil_mode;
+ bool count_include_pad;
String layout;
String out_layout;
@@ -277,6 +315,8 @@ struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
TVM_ATTR_FIELD(ceil_mode).describe(
"A boolean indicating if use ceil or floor to compute the output
shape. By using ceil, "
"every element in the input tensor will be covered by a sliding
window.");
+ TVM_ATTR_FIELD(count_include_pad)
+ .describe("When true, will include padding to compute the average");
TVM_ATTR_FIELD(layout).describe(
"Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
"'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
@@ -291,6 +331,46 @@ struct Pool2DAttrs : public tvm::AttrsNode<Pool2DAttrs> {
}
}; // struct Pool2dAttrs
+/*! \brief Attributes used in max_pool3d and avg_pool3d operator */
+struct Pool3DAttrs : public tvm::AttrsNode<Pool3DAttrs> {
+ Array<IntImm> pool_size;
+ Array<IntImm> strides;
+ Array<IntImm> padding;
+ Array<IntImm> dilation;
+ bool ceil_mode;
+ bool count_include_pad;
+ String layout;
+ String out_layout;
+
+ TVM_DECLARE_ATTRS(Pool3DAttrs, "relax.attrs.Pool3DAttrs") {
+ TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
+    TVM_ATTR_FIELD(strides).describe("Specifies the strides of the
pooling.");
+    TVM_ATTR_FIELD(dilation).describe("Specifies the dilation of the
pooling.");
+ TVM_ATTR_FIELD(padding).describe(
+ "If padding is non-zero, then the input is implicitly zero-padded"
+ "Padding support both symmetric and asymmetric as"
+ "one int : same padding used on all sides"
+ "three int : back, bottom, right will use same padding as front, top,
left"
+        "six int : padding width in the order of (front, top, left, back,
bottom, right)");
+ TVM_ATTR_FIELD(ceil_mode).describe(
+ "A boolean indicating if use ceil or floor to compute the output
shape. By using ceil, "
+ "every element in the input tensor will be covered by a sliding
window.");
+ TVM_ATTR_FIELD(count_include_pad)
+ .describe("When true, will include padding to compute the average");
+ TVM_ATTR_FIELD(layout).describe(
+ "Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and
width"
+ "dimensions respectively. Pooling is applied on the 'D', 'H' and"
+ "'W' dimensions.");
+ TVM_ATTR_FIELD(out_layout)
+ .describe(
+ "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
+ "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height,
and width"
+ "dimensions respectively. Pooling is applied on the 'D', 'H' and"
+ "'W' dimensions.");
+ }
+}; // struct Pool3DAttrs
+
/*! \brief Attributes for 2d adaptive pool operator */
struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
Optional<Array<IntImm>> output_size;
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 092e73baa1..a047e8701c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1438,21 +1438,40 @@ class BatchNormalization(OnnxOpConverter):
)
-class MaxPool(OnnxOpConverter):
- """Converts an onnx MaxPool node into an equivalent Relax expression."""
+class Pool(OnnxOpConverter):
+ """A helper class for pool op converters."""
+
+ name = ""
@classmethod
- def _impl_v12(cls, bb, inputs, attr, params):
+ def get_pad_pair(cls, input1d, kernel1d, stride1d, mode):
+ """infer pad size"""
+ if input1d % stride1d == 0:
+ pad = max(kernel1d - stride1d, 0)
+ else:
+ pad = max(kernel1d - (input1d % stride1d), 0)
+ pad_before = pad // 2
+ pad_after = pad - pad_before
+ if "LOWER" in mode:
+ return [pad_after, pad_before]
+ return [pad_before, pad_after]
+
+ @classmethod
+ def _impl_v1(cls, bb, inputs, attr, params):
# Unpack inputs and attributes.
data = inputs[0]
+ input_shape = data.struct_info.shape
+ ndim = len(input_shape)
+
auto_pad = attr.get("auto_pad", b"NOTSET").decode("utf-8")
ceil_mode = attr.get("ceil_mode", 0)
- dilations = attr.get("dilations", [1, 1])
+ dilations = attr.get("dilations", [1] * (ndim - 2))
kernel_shape = attr.get("kernel_shape")
pads = attr.get("pads", 0)
- strides = attr.get("strides", [1, 1])
+ strides = attr.get("strides", [1] * (ndim - 2))
+
+    assert len(kernel_shape) in [1, 2, 3], "Currently only 1D/2D/3D
pooling is supported."
- assert len(kernel_shape) == 2, "Currently only 2D pooling is
supported."
assert auto_pad in [
"NOTSET",
"SAME_UPPER",
@@ -1461,34 +1480,40 @@ class MaxPool(OnnxOpConverter):
], f"Value {auto_pad} in attribute auto_pad is invalid."
if auto_pad in ("SAME_UPPER", "SAME_LOWER"):
- input_spatial_shape = cls._get_input_spatial_shape(data)
- output_spatial_shape = [0 for _ in input_spatial_shape]
-
- pads = _np.array([(0, 0) for _ in range(len(kernel_shape))])
+ pads = []
+ if cls.name == "avg_pool":
+ for axis in range(len(input_shape) - 2):
+ axis_shape = input_shape[2 + axis]
+ stride = strides[axis]
+ kernel = kernel_shape[axis]
+ pad = cls.get_pad_pair(axis_shape, kernel, stride,
auto_pad)
+ pads.append(pad)
+ else:
+ input_spatial_shape = cls._get_input_spatial_shape(data)
+ output_spatial_shape = [0 for _ in input_spatial_shape]
+
+ for i, _ in enumerate(input_spatial_shape):
+ if auto_pad == "SAME_UPPER":
+ output_spatial_shape[i] =
int(_np.ceil(input_spatial_shape[i] / strides[i]))
+ else:
+ output_spatial_shape[i] = int(
+ _np.floor(input_spatial_shape[i] / strides[i])
+ )
+ pad_i = (
+ (output_spatial_shape[i] - 1) * strides[i]
+ + ((kernel_shape[i] - 1) * dilations[i] + 1)
+ - input_spatial_shape[i]
+ )
- for i, _ in enumerate(input_spatial_shape):
- if auto_pad == "SAME_UPPER":
- output_spatial_shape[i] =
int(_np.ceil(input_spatial_shape[i] / strides[i]))
- else:
- output_spatial_shape[i] =
int(_np.floor(input_spatial_shape[i] / strides[i]))
- pad_i = (
- (output_spatial_shape[i] - 1) * strides[i]
- + ((kernel_shape[i] - 1) * dilations[i] + 1)
- - input_spatial_shape[i]
- )
- if auto_pad == "SAME_UPPER":
- pads[i, 0] = pad_i // 2
- pads[i, 1] = pad_i - pads[i, 0]
- else:
- pads[i, 1] = pad_i // 2
- pads[i, 0] = pad_i - pads[i, 1]
+ if auto_pad == "SAME_UPPER":
+ pads.append([pad_i // 2, pad_i - pad_i // 2])
+ else:
+ pads.append([pad_i - pad_i // 2, pad_i // 2])
- # TODO(agladyshev): for now we support only 2D kernel
- # (top, left, bottom, right)
- flatten_pads = [pads[0][0], pads[1][0], pads[0][1], pads[1][1]]
- pads = tuple(flatten_pads)
+ pads = tuple([val for pair in zip(*pads) for val in pair])
- return relax.op.nn.max_pool2d(data, kernel_shape, strides, pads,
dilations, ceil_mode)
+ op = getattr(relax.op.nn, cls.name + str(len(kernel_shape)) + "d")
+ return op(data, kernel_shape, strides, pads, dilations, ceil_mode)
@classmethod
def _get_input_spatial_shape(cls, tensor):
@@ -1496,6 +1521,18 @@ class MaxPool(OnnxOpConverter):
return _np.array([int(d) for d in tensor.struct_info.shape],
dtype="int64")[2:]
+class MaxPool(Pool):
+ """Converts an onnx MaxPool node into an equivalent Relax expression."""
+
+ name = "max_pool"
+
+
+class AveragePool(Pool):
+    """Converts an onnx AveragePool node into an equivalent Relax expression."""
+
+ name = "avg_pool"
+
+
class GlobalAveragePool(OnnxOpConverter):
"""Converts an onnx GlobalAveragePool node into an equivalent Relax
expression."""
@@ -1922,9 +1959,10 @@ def _get_convert_map():
"Split": Split,
"Tile": Tile,
"BatchNormalization": BatchNormalization,
+ "MaxPool": MaxPool,
+ "AveragePool": AveragePool,
"GlobalAveragePool": GlobalAveragePool,
"Flatten": Flatten,
- "MaxPool": MaxPool,
"Identity": Identity,
"Resize": Resize,
"Einsum": Einsum,
diff --git a/python/tvm/relax/op/_op_gradient.py
b/python/tvm/relax/op/_op_gradient.py
index 1b0ebfd5e4..6878f97331 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -1279,6 +1279,7 @@ out_layout)`
orig_call.attrs.padding,
orig_call.attrs.dilation,
orig_call.attrs.ceil_mode,
+ orig_call.attrs.count_include_pad,
orig_call.attrs.layout,
orig_call.attrs.out_layout,
)
@@ -1310,6 +1311,7 @@ out_layout)`
orig_call.attrs.padding,
orig_call.attrs.dilation,
orig_call.attrs.ceil_mode,
+ orig_call.attrs.count_include_pad,
orig_call.attrs.layout,
orig_call.attrs.out_layout,
)
diff --git a/python/tvm/relax/op/grad/grad.py b/python/tvm/relax/op/grad/grad.py
index 2218db2232..304ad9cc2f 100644
--- a/python/tvm/relax/op/grad/grad.py
+++ b/python/tvm/relax/op/grad/grad.py
@@ -130,6 +130,7 @@ def max_pool2d_backward(
padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
dilation: Tuple[int, int] = (1, 1),
ceil_mode: bool = False,
+ count_include_pad: bool = False,
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
@@ -147,7 +148,16 @@ def max_pool2d_backward(
The gradient w.r.t. data.
"""
return _ffi_api.max_pool2d_backward( # type: ignore
- output_grad, data, pool_size, strides, padding, dilation, ceil_mode,
layout, out_layout
+ output_grad,
+ data,
+ pool_size,
+ strides,
+ padding,
+ dilation,
+ ceil_mode,
+ count_include_pad,
+ layout,
+ out_layout,
)
@@ -159,6 +169,7 @@ def avg_pool2d_backward(
padding: Tuple[int, int, int, int] = (0, 0, 0, 0),
dilation: Tuple[int, int] = (1, 1),
ceil_mode: bool = False,
+ count_include_pad: bool = False,
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
@@ -176,7 +187,16 @@ def avg_pool2d_backward(
The gradient w.r.t. data.
"""
return _ffi_api.avg_pool2d_backward( # type: ignore
- output_grad, data, pool_size, strides, padding, dilation, ceil_mode,
layout, out_layout
+ output_grad,
+ data,
+ pool_size,
+ strides,
+ padding,
+ dilation,
+ ceil_mode,
+ count_include_pad,
+ layout,
+ out_layout,
)
diff --git a/python/tvm/relax/op/nn/__init__.py
b/python/tvm/relax/op/nn/__init__.py
index d90b207314..cb90a86883 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -19,7 +19,9 @@ from .nn import (
adaptive_avg_pool2d,
attention,
attention_var_len,
+ avg_pool1d,
avg_pool2d,
+ avg_pool3d,
batch_norm,
conv1d,
conv1d_transpose,
@@ -34,7 +36,9 @@ from .nn import (
layer_norm,
leakyrelu,
log_softmax,
+ max_pool1d,
max_pool2d,
+ max_pool3d,
nll_loss,
pad,
relu,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 151c43af55..26ba894e84 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -542,6 +542,87 @@ def pad(data, pad_width, pad_value=0, pad_mode="constant"):
return _ffi_api.pad(data, pad_width, pad_value, pad_mode)
+def max_pool1d(
+ data: Expr,
+ pool_size: Union[int, Tuple[int, int]] = (1,),
+ strides: Union[int, Tuple[int, int]] = (1,),
+ padding: Union[int, Tuple[int, ...]] = (0, 0),
+ dilation: Union[int, Tuple[int, int]] = (1,),
+ ceil_mode: bool = False,
+ count_include_pad: bool = False,
+ layout: str = "NCW",
+ out_layout: Optional[str] = None,
+) -> Expr:
+ r"""1D maximum pooling operator.
+
+ This operator takes data as input and does 1D max value calculation
+    within pool_size sized window by striding defined by stride.
+
+    In the default case, where the data_layout is `NCW`
+ a data Tensor with shape `(batch_size, channels, width)`,
+ to produce an output Tensor.
+
+ The ceil_mode is used to take ceil or floor while computing out shape.
+ count_include_pad indicates including or excluding padded input values in
computation.
+ This operator accepts data layout specification.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+    pool_size : Union[int, Tuple[int, int]]
+        The size of window for pooling. It is required to have length 1.
+
+    strides : Union[int, Tuple[int, int]]
+        The strides of pooling. It is required to have length 1.
+
+    padding : Union[int, Tuple[int, ...]]
+        The padding for pooling. It is required to have length either 1 or 2.
+
+    dilation : Union[int, Tuple[int, int]]
+        The dilation of pooling. It is required to have length 1.
+
+ ceil_mode : bool
+ A boolean indicating if use ceil or floor to compute the output shape.
+ By using ceil, every element in the input tensor will be covered by a
sliding window.
+
+ count_include_pad : bool, optional
+ To include padding to compute the average.
+
+ layout : str
+ Layout of the input.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout
+
+ Returns
+ -------
+ result : Expr
+ The computed result.
+ """
+ if isinstance(pool_size, int):
+ pool_size = (pool_size,)
+ if isinstance(strides, int):
+ strides = (strides,)
+ if isinstance(dilation, int):
+ dilation = (dilation,)
+ if isinstance(padding, int):
+ padding = (padding, padding)
+
+ return _ffi_api.max_pool1d( # type: ignore
+ data,
+ pool_size,
+ strides,
+ padding,
+ dilation,
+ ceil_mode,
+ count_include_pad,
+ layout,
+ out_layout,
+ )
+
+
def max_pool2d(
data: Expr,
pool_size: Union[int, Tuple[int, int]] = (1, 1),
@@ -549,6 +630,7 @@ def max_pool2d(
padding: Union[int, Tuple[int, ...]] = (0, 0),
dilation: Union[int, Tuple[int, int]] = (1, 1),
ceil_mode: bool = False,
+ count_include_pad: bool = False,
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
@@ -593,6 +675,9 @@ def max_pool2d(
A boolean indicating if use ceil or floor to compute the output shape.
By using ceil, every element in the input tensor will be covered by a
sliding window.
+ count_include_pad : bool, optional
+ To include padding to compute the average.
+
layout : str
Layout of the input.
@@ -614,7 +699,177 @@ def max_pool2d(
padding = (padding, padding, padding, padding)
return _ffi_api.max_pool2d( # type: ignore
- data, pool_size, strides, padding, dilation, ceil_mode, layout,
out_layout
+ data,
+ pool_size,
+ strides,
+ padding,
+ dilation,
+ ceil_mode,
+ count_include_pad,
+ layout,
+ out_layout,
+ )
+
+
+def max_pool3d(
+ data: Expr,
+ pool_size: Union[int, Tuple[int, int]] = (1, 1, 1),
+ strides: Union[int, Tuple[int, int]] = (1, 1, 1),
+ padding: Union[int, Tuple[int, ...]] = (0, 0, 0),
+ dilation: Union[int, Tuple[int, int]] = (1, 1, 1),
+ ceil_mode: bool = False,
+ count_include_pad: bool = False,
+ layout: str = "NCDHW",
+ out_layout: Optional[str] = None,
+) -> Expr:
+ r"""3D maximum pooling operator.
+
+ This operator takes data as input and does 3D max value calculation
+    within pool_size sized window by striding defined by stride.
+
+
+ In the default case, where the data_layout is `NCDHW`
+ a data Tensor with shape `(batch_size, channels, depth, height, width)`,
+ to produce an output Tensor.
+
+ The ceil_mode is used to take ceil or floor while computing out shape.
+ count_include_pad indicates including or excluding padded input values in
computation.
+ This operator accepts data layout specification.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ pool_size : Union[int, Tuple[int, int]]
+ The size of window for pooling. It is required to have length either 1
or 3.
+
+ strides : Union[int, Tuple[int, int]]
+ The strides of pooling. It is required to have length either 1 or 3.
+
+ padding : Union[int, Tuple[int, ...]]
+ The padding for pooling. It is required to have length either 1, 3 or
6.
+
+ dilation : Union[int, Tuple[int, int]]
+ The dilation of pooling. It is required to have length either 1 or 3.
+
+ ceil_mode : bool
+ A boolean indicating if use ceil or floor to compute the output shape.
+ By using ceil, every element in the input tensor will be covered by a
sliding window.
+
+ count_include_pad : bool, optional
+ To include padding to compute the average.
+
+ layout : str
+ Layout of the input.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout
+
+ Returns
+ -------
+ result : Expr
+ The computed result.
+ """
+ if isinstance(pool_size, int):
+ pool_size = (pool_size, pool_size, pool_size)
+ if isinstance(strides, int):
+ strides = (strides, strides, strides)
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation, dilation)
+ if isinstance(padding, int):
+ padding = (padding, padding, padding, padding, padding, padding)
+
+ return _ffi_api.max_pool3d( # type: ignore
+ data,
+ pool_size,
+ strides,
+ padding,
+ dilation,
+ ceil_mode,
+ count_include_pad,
+ layout,
+ out_layout,
+ )
+
+
+def avg_pool1d(
+ data: Expr,
+ pool_size: Union[int, Tuple[int, int]] = (1,),
+ strides: Union[int, Tuple[int, int]] = (1,),
+ padding: Union[int, Tuple[int, ...]] = (0, 0),
+ dilation: Union[int, Tuple[int, int]] = (1,),
+ ceil_mode: bool = False,
+ count_include_pad: bool = False,
+ layout: str = "NCW",
+ out_layout: Optional[str] = None,
+) -> Expr:
+ r"""1D average pooling operator.
+
+ This operator takes data as input and does 1D average value calculation
+    within pool_size sized window by striding defined by stride
+
+ In the default case, where the data_layout is `NCW`
+ a data Tensor with shape `(batch_size, channels, width)`,
+ to produce an output Tensor.
+
+ The ceil_mode is used to take ceil or floor while computing out shape.
+ count_include_pad indicates including or excluding padded input values in
computation.
+ This operator accepts data layout specification.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+    pool_size : Union[int, Tuple[int]]
+        The size of window for pooling. It is required to have length 1.
+
+    strides : Union[int, Tuple[int]]
+        The strides of pooling. It is required to have length 1.
+
+    padding : Union[int, Tuple[int, int]]
+        The padding for pooling. It is required to have length either 1 or 2.
+
+    dilation : Union[int, Tuple[int]]
+        The dilation of pooling. It is required to have length 1.
+
+ ceil_mode : bool
+ A boolean indicating if use ceil or floor to compute the output shape.
+ By using ceil, every element in the input tensor will be covered by a
sliding window.
+
+ count_include_pad : bool, optional
+ To include padding to compute the average.
+
+ layout : str
+ Layout of the input.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout
+
+ Returns
+ -------
+ result : Expr
+ The computed result.
+ """
+ if isinstance(pool_size, int):
+ pool_size = (pool_size,)
+ if isinstance(strides, int):
+ strides = (strides,)
+ if isinstance(dilation, int):
+ dilation = (dilation,)
+ if isinstance(padding, int):
+ padding = (padding, padding)
+ return _ffi_api.avg_pool1d( # type: ignore
+ data,
+ pool_size,
+ strides,
+ padding,
+ dilation,
+ ceil_mode,
+ count_include_pad,
+ layout,
+ out_layout,
)
@@ -625,6 +880,7 @@ def avg_pool2d(
padding: Union[int, Tuple[int, ...]] = (0, 0),
dilation: Union[int, Tuple[int, int]] = (1, 1),
ceil_mode: bool = False,
+ count_include_pad: bool = False,
layout: str = "NCHW",
out_layout: Optional[str] = None,
) -> Expr:
@@ -670,6 +926,9 @@ def avg_pool2d(
A boolean indicating if use ceil or floor to compute the output shape.
By using ceil, every element in the input tensor will be covered by a
sliding window.
+ count_include_pad : bool, optional
+ To include padding to compute the average.
+
layout : str
Layout of the input.
@@ -689,9 +948,98 @@ def avg_pool2d(
dilation = (dilation, dilation)
if isinstance(padding, int):
padding = (padding, padding, padding, padding)
-
return _ffi_api.avg_pool2d( # type: ignore
- data, pool_size, strides, padding, dilation, ceil_mode, layout,
out_layout
+ data,
+ pool_size,
+ strides,
+ padding,
+ dilation,
+ ceil_mode,
+ count_include_pad,
+ layout,
+ out_layout,
+ )
+
+
+def avg_pool3d(
+ data: Expr,
+ pool_size: Union[int, Tuple[int, int]] = (1, 1, 1),
+ strides: Union[int, Tuple[int, int]] = (1, 1, 1),
+ padding: Union[int, Tuple[int, ...]] = (0, 0, 0),
+ dilation: Union[int, Tuple[int, int]] = (1, 1, 1),
+ ceil_mode: bool = False,
+ count_include_pad: bool = False,
+ layout: str = "NCDHW",
+ out_layout: Optional[str] = None,
+) -> Expr:
+    r"""3D average pooling operator.
+
+ This operator takes data as input and does 3D average value calculation
+    within pool_size sized window by striding defined by stride
+
+
+ In the default case, where the data_layout is `NCDHW`
+ a data Tensor with shape `(batch_size, channels, depth, height, width)`,
+ to produce an output Tensor.
+
+ The ceil_mode is used to take ceil or floor while computing out shape.
+ count_include_pad indicates including or excluding padded input values in
computation.
+ This operator accepts data layout specification.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ pool_size : Union[int, Tuple[int, int, int]]
+ The size of window for pooling. It is required to have length either 1
or 3.
+
+ strides : Union[int, Tuple[int, int, int]]
+ The strides of pooling. It is required to have length either 1 or 3.
+
+ padding : Union[int, Tuple[int, ...]]
+ The padding for pooling. It is required to have length either 1, 3 or
6.
+
+ dilation : Union[int, Tuple[int, int, int]]
+ The dilation of pooling. It is required to have length either 1 or 3.
+
+ ceil_mode : bool
+ A boolean indicating if use ceil or floor to compute the output shape.
+ By using ceil, every element in the input tensor will be covered by a
sliding window.
+
+ count_include_pad : bool, optional
+ To include padding to compute the average.
+
+ layout : str
+ Layout of the input.
+
+ out_layout : Optional[str]
+ Layout of the output. If not specified, it is the same as data_layout
+
+ Returns
+ -------
+ result : Expr
+ The computed result.
+ """
+ if isinstance(pool_size, int):
+ pool_size = (pool_size, pool_size, pool_size)
+ if isinstance(strides, int):
+ strides = (strides, strides, strides)
+ if isinstance(dilation, int):
+ dilation = (dilation, dilation, dilation)
+ if isinstance(padding, int):
+ padding = (padding, padding, padding, padding, padding, padding)
+
+ return _ffi_api.avg_pool3d( # type: ignore
+ data,
+ pool_size,
+ strides,
+ padding,
+ dilation,
+ ceil_mode,
+ count_include_pad,
+ layout,
+ out_layout,
)
diff --git a/python/tvm/relax/transform/legalize_ops/grad.py
b/python/tvm/relax/transform/legalize_ops/grad.py
index 1d527bea6a..4fde2a25c3 100644
--- a/python/tvm/relax/transform/legalize_ops/grad.py
+++ b/python/tvm/relax/transform/legalize_ops/grad.py
@@ -125,6 +125,7 @@ def _grad_max_pool2d_backward(bb: BlockBuilder, call: Call)
-> Expr:
padding=call.attrs.padding,
pool_type="max",
ceil_mode=call.attrs.ceil_mode,
+ count_include_pad=call.attrs.count_include_pad,
layout=call.attrs.layout,
primfunc_name_hint="max_pool2d_backward",
)
@@ -144,6 +145,7 @@ def _grad_avg_pool2d_backward(bb: BlockBuilder, call: Call)
-> Expr:
padding=call.attrs.padding,
pool_type="avg",
ceil_mode=call.attrs.ceil_mode,
+ count_include_pad=call.attrs.count_include_pad,
layout=call.attrs.layout,
primfunc_name_hint="avg_pool2d_backward",
)
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py
b/python/tvm/relax/transform/legalize_ops/nn.py
index f80d28099c..8f5407ff09 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -241,6 +241,29 @@ def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.max_pool1d")
+def _nn_max_pool1d(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.layout:
+ logging.info(
+ "TOPI max_pool1d does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.pool1d,
+ call.args[0],
+ kernel=call.attrs.pool_size,
+ stride=call.attrs.strides,
+ dilation=call.attrs.dilation,
+ padding=call.attrs.padding,
+ pool_type="max",
+ ceil_mode=call.attrs.ceil_mode,
+ layout=call.attrs.layout,
+ primfunc_name_hint="max_pool1d",
+ )
+
+
@register_legalize("relax.nn.max_pool2d")
def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.layout:
@@ -264,6 +287,53 @@ def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
)
+@register_legalize("relax.nn.max_pool3d")
+def _nn_max_pool3d(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.layout:
+ logging.info(
+ "TOPI max_pool3d does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.pool3d,
+ call.args[0],
+ kernel=call.attrs.pool_size,
+ stride=call.attrs.strides,
+ dilation=call.attrs.dilation,
+ padding=call.attrs.padding,
+ pool_type="max",
+ ceil_mode=call.attrs.ceil_mode,
+ layout=call.attrs.layout,
+ primfunc_name_hint="max_pool3d",
+ )
+
+
+@register_legalize("relax.nn.avg_pool1d")
+def _nn_avg_pool1d(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.layout:
+ logging.info(
+ "TOPI avg_pool1d does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.pool1d,
+ call.args[0],
+ kernel=call.attrs.pool_size,
+ stride=call.attrs.strides,
+ dilation=call.attrs.dilation,
+ padding=call.attrs.padding,
+ pool_type="avg",
+ ceil_mode=call.attrs.ceil_mode,
+ layout=call.attrs.layout,
+ count_include_pad=call.attrs.count_include_pad,
+ primfunc_name_hint="avg_pool1d",
+ )
+
+
@register_legalize("relax.nn.avg_pool2d")
def _nn_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.layout:
@@ -283,10 +353,35 @@ def _nn_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
pool_type="avg",
ceil_mode=call.attrs.ceil_mode,
layout=call.attrs.layout,
+ count_include_pad=call.attrs.count_include_pad,
primfunc_name_hint="avg_pool2d",
)
+@register_legalize("relax.nn.avg_pool3d")
+def _nn_avg_pool3d(bb: BlockBuilder, call: Call) -> Expr:
+ if call.attrs.out_layout != call.attrs.layout:
+ logging.info(
+ "TOPI avg_pool3d does not support different input-output "
+ "layouts, and thus cannot be legalized by TOPI"
+ )
+ return call
+
+ return bb.call_te(
+ topi.nn.pool3d,
+ call.args[0],
+ kernel=call.attrs.pool_size,
+ stride=call.attrs.strides,
+ dilation=call.attrs.dilation,
+ padding=call.attrs.padding,
+ pool_type="avg",
+ ceil_mode=call.attrs.ceil_mode,
+ layout=call.attrs.layout,
+ count_include_pad=call.attrs.count_include_pad,
+ primfunc_name_hint="avg_pool3d",
+ )
+
+
@register_legalize("relax.nn.adaptive_avg_pool2d")
def _nn_adaptive_avg_pool2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.layout:
diff --git a/python/tvm/topi/nn/pooling.py b/python/tvm/topi/nn/pooling.py
index b12c492ed8..a45480f12e 100644
--- a/python/tvm/topi/nn/pooling.py
+++ b/python/tvm/topi/nn/pooling.py
@@ -65,8 +65,8 @@ def pool_grad(
padding,
pool_type,
ceil_mode=False,
- layout="NCHW",
count_include_pad=True,
+ layout="NCHW",
):
"""Gradient of pooling on height and width dimension of data.
It decides the height and width dimension according to the layout
string,
@@ -99,6 +99,9 @@ def pool_grad(
ceil_mode : bool
Whether to use ceil when calculating output size.
+ count_include_pad: bool
+ Whether include padding in the calculation when pool_type is 'avg'
+
layout: string
Layout of the input data.
The layout is supposed to be composed of upper cases, lower cases and
numbers,
@@ -108,8 +111,6 @@ def pool_grad(
[batch_size, channel, height, width, channel_block],
in which channel_block=16 is a split of dimension channel.
- count_include_pad: bool
- Whether include padding in the calculation when pool_type is 'avg'
Returns
-------
diff --git a/src/relax/op/nn/pooling.cc b/src/relax/op/nn/pooling.cc
index 6c81f5310a..865d419bca 100644
--- a/src/relax/op/nn/pooling.cc
+++ b/src/relax/op/nn/pooling.cc
@@ -25,12 +25,116 @@
namespace tvm {
namespace relax {
-/* relax.nn.max_pool2d and relax.nn.avg_pool2d */
+/* relax.nn.max_pool1d */
+TVM_REGISTER_NODE_TYPE(Pool1DAttrs);
+
+Expr MakePool1d(String op_name, Expr data, Array<IntImm> pool_size,
Array<IntImm> strides,
+ Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode,
+ bool count_include_pad, String layout, Optional<String>
out_layout) {
+ padding = GetCompletePadding1D(std::move(padding));
+
+ CHECK_EQ(pool_size.size(), 1)
+ << "The input pool_size length is expected to be 1. However, the given
pool_size is "
+ << pool_size;
+ CHECK_EQ(strides.size(), 1)
+ << "The input strides length is expected to be 1. However, the given
strides is " << strides;
+ CHECK_EQ(dilation.size(), 1)
+ << "The input dilation length is expected to be 1. However, the given
dilation is "
+ << dilation;
+
+ auto attrs = make_object<Pool1DAttrs>();
+ attrs->pool_size = ConvertIntImmToInt64(pool_size);
+ attrs->strides = ConvertIntImmToInt64(strides);
+ attrs->padding = ConvertIntImmToInt64(padding);
+ attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->ceil_mode = ceil_mode;
+ attrs->count_include_pad = count_include_pad;
+ attrs->layout = layout;
+ attrs->out_layout = out_layout.value_or(layout);
+ const Op& op = Op::Get(op_name);
+ return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+Expr max_pool1d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad, String layout,
+ Optional<String> out_layout) {
+ return MakePool1d("relax.nn.max_pool1d", data, pool_size, strides, padding,
dilation, ceil_mode,
+ count_include_pad, layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.max_pool1d").set_body_typed(max_pool1d);
+
+StructInfo InferStructInfoPool1D(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+  const auto* attrs = call->attrs.as<Pool1DAttrs>();
+  auto [data_layout, data2NCW] = CheckTensorLayout(call, ctx, attrs->layout,
+                                                   /*tgt_layout=*/"NCW",
+                                                   /*tensor_name=*/"data");
+  auto [out_layout, out2NCW] = CheckTensorLayout(call, ctx, attrs->out_layout,
+                                                 /*tgt_layout=*/"NCW",
+                                                 /*tensor_name=*/"output");
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+  if (!data_shape.defined()) {
+    // Input shape unknown: only dtype/ndim/vdevice of the output can be inferred.
+    return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(), data_sinfo->vdevice);
+  }
+
+  Array<PrimExpr> data_NCW_shape = data2NCW.ForwardShape(data_shape.value()->values);
+
+  PrimExpr input_w = data_NCW_shape[2];
+  PrimExpr kernel_w = attrs->pool_size[0];
+  PrimExpr padding_w = attrs->padding[0] + attrs->padding[1];
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  std::vector<PrimExpr> out_NCW_shape;
+  out_NCW_shape.resize(3);
+  out_NCW_shape[0] = data_NCW_shape[0];
+  out_NCW_shape[1] = data_NCW_shape[1];
+
+  PrimExpr numerator_w = input_w + padding_w - attrs->dilation[0] * (kernel_w - 1) - 1;
+  if (attrs->ceil_mode) {
+    // ceil_mode rounds the output length up by biasing the numerator with
+    // (stride - 1). Pool1D has exactly one stride entry (MakePool1d enforces
+    // strides.size() == 1), so index 0 is the only valid one; strides[1]
+    // would read out of bounds.
+    numerator_w += attrs->strides[0] - 1;
+  }
+  out_NCW_shape[2] = analyzer->Simplify(floordiv(numerator_w, attrs->strides[0]) + 1);
+
+  Array<PrimExpr> out_shape = out2NCW.BackwardShape(out_NCW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
+}
+
+InferLayoutOutput InferLayoutPool1d(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK_EQ(tensor_sinfo->ndim, 3) << "Unsupported initial layout";
+ const auto* attrs = call->attrs.as<Pool1DAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<Pool1DAttrs> new_attrs = make_object<Pool1DAttrs>(*attrs);
+ new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(3),
layout->layout).name();
+ new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(3),
layout->layout).name();
+ return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.nn.max_pool1d")
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor")
+ .set_attrs_type<Pool1DAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool1D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool1d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.nn.max_pool2d */
TVM_REGISTER_NODE_TYPE(Pool2DAttrs);
Expr MakePool2d(String op_name, Expr data, Array<IntImm> pool_size,
Array<IntImm> strides,
- Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode,
String layout,
- Optional<String> out_layout) {
+ Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode,
+ bool count_include_pad, String layout, Optional<String>
out_layout) {
padding = GetCompletePadding2D(std::move(padding));
if (pool_size.size() == 1) {
pool_size.push_back(pool_size[0]);
@@ -57,6 +161,7 @@ Expr MakePool2d(String op_name, Expr data, Array<IntImm>
pool_size, Array<IntImm
attrs->padding = ConvertIntImmToInt64(padding);
attrs->dilation = ConvertIntImmToInt64(dilation);
attrs->ceil_mode = ceil_mode;
+ attrs->count_include_pad = count_include_pad;
attrs->layout = layout;
attrs->out_layout = out_layout.value_or(layout);
const Op& op = Op::Get(op_name);
@@ -64,10 +169,10 @@ Expr MakePool2d(String op_name, Expr data, Array<IntImm>
pool_size, Array<IntImm
}
Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
- Array<IntImm> dilation, bool ceil_mode, String layout,
+ Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad, String layout,
Optional<String> out_layout) {
return MakePool2d("relax.nn.max_pool2d", data, pool_size, strides, padding,
dilation, ceil_mode,
- layout, out_layout);
+ count_include_pad, layout, out_layout);
}
TVM_REGISTER_GLOBAL("relax.op.nn.max_pool2d").set_body_typed(max_pool2d);
@@ -143,11 +248,159 @@ TVM_REGISTER_OP("relax.nn.max_pool2d")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.max_pool3d */
+TVM_REGISTER_NODE_TYPE(Pool3DAttrs);
+
+Expr MakePool3d(String op_name, Expr data, Array<IntImm> pool_size,
Array<IntImm> strides,
+ Array<IntImm> padding, Array<IntImm> dilation, bool ceil_mode,
+ bool count_include_pad, String layout, Optional<String>
out_layout) {
+ padding = GetCompletePadding3D(std::move(padding));
+ if (pool_size.size() == 1) {
+ pool_size.push_back(pool_size[0]);
+ pool_size.push_back(pool_size[0]);
+ }
+ if (strides.size() == 1) {
+ strides.push_back(strides[0]);
+ strides.push_back(strides[0]);
+ }
+ if (dilation.size() == 1) {
+ dilation.push_back(dilation[0]);
+ dilation.push_back(dilation[0]);
+ }
+
+ CHECK_EQ(pool_size.size(), 3)
+ << "The input pool_size length is expected to be 3. However, the given
pool_size is "
+ << pool_size;
+ CHECK_EQ(strides.size(), 3)
+ << "The input strides length is expected to be 3. However, the given
strides is " << strides;
+ CHECK_EQ(dilation.size(), 3)
+ << "The input dilation length is expected to be 3. However, the given
dilation is "
+ << dilation;
+
+ auto attrs = make_object<Pool3DAttrs>();
+ attrs->pool_size = ConvertIntImmToInt64(pool_size);
+ attrs->strides = ConvertIntImmToInt64(strides);
+ attrs->padding = ConvertIntImmToInt64(padding);
+ attrs->dilation = ConvertIntImmToInt64(dilation);
+ attrs->ceil_mode = ceil_mode;
+ attrs->count_include_pad = count_include_pad;
+ attrs->layout = layout;
+ attrs->out_layout = out_layout.value_or(layout);
+ const Op& op = Op::Get(op_name);
+ return Call(op, {std::move(data)}, Attrs(attrs), {});
+}
+
+Expr max_pool3d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad, String layout,
+ Optional<String> out_layout) {
+ return MakePool3d("relax.nn.max_pool3d", data, pool_size, strides, padding,
dilation, ceil_mode,
+ count_include_pad, layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.max_pool3d").set_body_typed(max_pool3d);
+
+StructInfo InferStructInfoPool3D(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+ const auto* attrs = call->attrs.as<Pool3DAttrs>();
+ auto [data_layout, data2NCDHW] = CheckTensorLayout(call, ctx, attrs->layout,
+ /*tgt_layout=*/"NCDHW",
+ /*tensor_name=*/"data");
+ auto [out_layout, out2NCDHW] = CheckTensorLayout(call, ctx,
attrs->out_layout,
+ /*tgt_layout=*/"NCDHW",
+ /*tensor_name=*/"output");
+
+ Optional<ShapeExpr> data_shape =
+ CheckNdimPerLayoutAndGetShape(call, ctx, data_sinfo, data_layout);
+ if (!data_shape.defined()) {
+ return TensorStructInfo(data_sinfo->dtype, out_layout.ndim(),
data_sinfo->vdevice);
+ }
+
+ Array<PrimExpr> data_NCDHW_shape =
data2NCDHW.ForwardShape(data_shape.value()->values);
+
+ PrimExpr input_d = data_NCDHW_shape[2];
+ PrimExpr input_h = data_NCDHW_shape[3];
+ PrimExpr input_w = data_NCDHW_shape[4];
+ PrimExpr kernel_d = attrs->pool_size[0];
+ PrimExpr kernel_h = attrs->pool_size[1];
+ PrimExpr kernel_w = attrs->pool_size[2];
+ PrimExpr padding_d = attrs->padding[0] + attrs->padding[3];
+ PrimExpr padding_h = attrs->padding[1] + attrs->padding[4];
+ PrimExpr padding_w = attrs->padding[2] + attrs->padding[5];
+
+ arith::Analyzer* analyzer = ctx->GetAnalyzer();
+ std::vector<PrimExpr> out_NCDHW_shape;
+ out_NCDHW_shape.resize(5);
+ out_NCDHW_shape[0] = data_NCDHW_shape[0];
+ out_NCDHW_shape[1] = data_NCDHW_shape[1];
+
+ PrimExpr numerator_d = input_d + padding_d - attrs->dilation[0] * (kernel_d
- 1) - 1;
+ PrimExpr numerator_h = input_h + padding_h - attrs->dilation[1] * (kernel_h
- 1) - 1;
+ PrimExpr numerator_w = input_w + padding_w - attrs->dilation[2] * (kernel_w
- 1) - 1;
+ if (attrs->ceil_mode) {
+ numerator_d += attrs->strides[0] - 1;
+ numerator_h += attrs->strides[1] - 1;
+ numerator_w += attrs->strides[2] - 1;
+ }
+ out_NCDHW_shape[2] = analyzer->Simplify(floordiv(numerator_d,
attrs->strides[0]) + 1);
+ out_NCDHW_shape[3] = analyzer->Simplify(floordiv(numerator_h,
attrs->strides[1]) + 1);
+ out_NCDHW_shape[4] = analyzer->Simplify(floordiv(numerator_w,
attrs->strides[2]) + 1);
+
+ Array<PrimExpr> out_shape = out2NCDHW.BackwardShape(out_NCDHW_shape);
+ return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype,
data_sinfo->vdevice);
+}
+
+InferLayoutOutput InferLayoutPool3d(const Call& call,
+ const Map<String, Array<String>>&
desired_layouts,
+ const VarLayoutMap& var_layout_map) {
+ ICHECK(NoDesiredLayout(call, desired_layouts));
+ const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call);
+ ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+ ICHECK_EQ(tensor_sinfo->ndim, 5) << "Unsupported initial layout";
+ const auto* attrs = call->attrs.as<Pool3DAttrs>();
+ ICHECK(attrs) << "Invalid Call";
+
+ LayoutDecision layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ ObjectPtr<Pool3DAttrs> new_attrs = make_object<Pool3DAttrs>(*attrs);
+ new_attrs->layout = TransposeLike(attrs->layout, InitialLayout(5),
layout->layout).name();
+ new_attrs->out_layout = TransposeLike(attrs->out_layout, InitialLayout(5),
layout->layout).name();
+ return InferLayoutOutput({layout}, {layout}, Attrs(new_attrs));
+}
+
+TVM_REGISTER_OP("relax.nn.max_pool3d")
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor")
+ .set_attrs_type<Pool3DAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool3D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool3d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.nn.avg_pool1d */
+Expr avg_pool1d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad, String layout,
+ Optional<String> out_layout) {
+ return MakePool1d("relax.nn.avg_pool1d", data, pool_size, strides, padding,
dilation, ceil_mode,
+ count_include_pad, layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool1d").set_body_typed(avg_pool1d);
+
+TVM_REGISTER_OP("relax.nn.avg_pool1d")
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor")
+ .set_attrs_type<Pool1DAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool1D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool1d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
+/* relax.nn.avg_pool2d */
Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
- Array<IntImm> dilation, bool ceil_mode, String layout,
+ Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad, String layout,
Optional<String> out_layout) {
return MakePool2d("relax.nn.avg_pool2d", data, pool_size, strides, padding,
dilation, ceil_mode,
- layout, out_layout);
+ count_include_pad, layout, out_layout);
}
TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool2d").set_body_typed(avg_pool2d);
@@ -161,6 +414,25 @@ TVM_REGISTER_OP("relax.nn.avg_pool2d")
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.avg_pool3d */
+Expr avg_pool3d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
+ Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad, String layout,
+ Optional<String> out_layout) {
+ return MakePool3d("relax.nn.avg_pool3d", data, pool_size, strides, padding,
dilation, ceil_mode,
+ count_include_pad, layout, out_layout);
+}
+
+TVM_REGISTER_GLOBAL("relax.op.nn.avg_pool3d").set_body_typed(avg_pool3d);
+
+TVM_REGISTER_OP("relax.nn.avg_pool3d")
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor")
+ .set_attrs_type<Pool3DAttrs>()
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPool3D)
+ .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutPool3d)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy",
MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
/* relax.nn.adaptive_avg_pool2d */
TVM_REGISTER_NODE_TYPE(AdaptivePool2DAttrs);
diff --git a/src/relax/op/nn/pooling.h b/src/relax/op/nn/pooling.h
index 63d2e76772..7fd66f2b44 100644
--- a/src/relax/op/nn/pooling.h
+++ b/src/relax/op/nn/pooling.h
@@ -34,11 +34,13 @@ namespace relax {
/*! \brief 2D maximum pooling operator. */
Expr max_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
- Array<IntImm> dilation, bool ceil_mode, String layout,
Optional<String> out_layout);
+ Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad, String layout,
+ Optional<String> out_layout);
/*! \brief 2D average pooling operator. */
Expr avg_pool2d(Expr data, Array<IntImm> pool_size, Array<IntImm> strides,
Array<IntImm> padding,
- Array<IntImm> dilation, bool ceil_mode, String layout,
Optional<String> out_layout);
+ Array<IntImm> dilation, bool ceil_mode, bool
count_include_pad, String layout,
+ Optional<String> out_layout);
/*! \brief 2D adaptive average pooling operator. */
Expr adaptive_avg_pool2d(Expr data, Optional<Array<IntImm>> output_size,
String layout,
diff --git a/src/relax/op/tensor/grad.cc b/src/relax/op/tensor/grad.cc
index 6f30684460..70114093e3 100644
--- a/src/relax/op/tensor/grad.cc
+++ b/src/relax/op/tensor/grad.cc
@@ -130,13 +130,15 @@ TVM_REGISTER_OP("relax.grad.nll_loss_backward")
/* relax.grad.max_pool2d_backward */
Expr max_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
Array<IntImm> strides, Array<IntImm> padding,
Array<IntImm> dilation,
- bool ceil_mode, String layout, Optional<String>
out_layout) {
+ bool ceil_mode, bool count_include_pad, String layout,
+ Optional<String> out_layout) {
auto attrs = make_object<Pool2DAttrs>();
attrs->pool_size = std::move(pool_size);
attrs->strides = ConvertIntImmToInt64(strides);
attrs->padding = ConvertIntImmToInt64(padding);
attrs->dilation = ConvertIntImmToInt64(dilation);
attrs->ceil_mode = ceil_mode;
+ attrs->count_include_pad = count_include_pad;
attrs->layout = layout;
attrs->out_layout = out_layout.value_or(layout);
static const Op& op = Op::Get("relax.grad.max_pool2d_backward");
@@ -160,13 +162,15 @@ TVM_REGISTER_OP("relax.grad.max_pool2d_backward")
/* relax.grad.avg_pool2d_backward */
Expr avg_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
Array<IntImm> strides, Array<IntImm> padding,
Array<IntImm> dilation,
- bool ceil_mode, String layout, Optional<String>
out_layout) {
+ bool ceil_mode, bool count_include_pad, String layout,
+ Optional<String> out_layout) {
auto attrs = make_object<Pool2DAttrs>();
attrs->pool_size = std::move(pool_size);
attrs->strides = ConvertIntImmToInt64(strides);
attrs->padding = ConvertIntImmToInt64(padding);
attrs->dilation = ConvertIntImmToInt64(dilation);
attrs->ceil_mode = ceil_mode;
+ attrs->count_include_pad = count_include_pad;
attrs->layout = layout;
attrs->out_layout = out_layout.value_or(layout);
static const Op& op = Op::Get("relax.grad.avg_pool2d_backward");
diff --git a/src/relax/op/tensor/grad.h b/src/relax/op/tensor/grad.h
index 886516020d..228de315af 100644
--- a/src/relax/op/tensor/grad.h
+++ b/src/relax/op/tensor/grad.h
@@ -48,13 +48,15 @@ Expr nll_loss_backward(Expr output_grad, Expr predictions,
Expr targets, Optiona
* relax.max_pool2d. Returns the gradient w.r.t. data. */
Expr max_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
Array<IntImm> strides, Array<IntImm> padding,
Array<IntImm> dilation,
- bool ceil_mode, String layout, Optional<String>
out_layout);
+ bool ceil_mode, bool count_include_pad, String layout,
+ Optional<String> out_layout);
/*! \brief Backward operator of relax.avg_pool2d. All parameters except
output_grad is the same as
* relax.avg_pool2d. Returns the gradient w.r.t. data. */
Expr avg_pool2d_backward(Expr output_grad, Expr data, Array<IntImm> pool_size,
Array<IntImm> strides, Array<IntImm> padding,
Array<IntImm> dilation,
- bool ceil_mode, String layout, Optional<String>
out_layout);
+ bool ceil_mode, bool count_include_pad, String layout,
+ Optional<String> out_layout);
/*! \brief Backward operator of relax.take. All parameters except output_grad
is the same as
* relax.take. Returns the gradient w.r.t. data. */
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 473766b749..32778cdd55 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1587,69 +1587,194 @@ def test_batch_norm():
check_correctness(model, opset=15)
-def test_max_pool():
- # Pool2D
- verify_unary(
- "MaxPool",
- [1, 1, 32, 32],
- dict(
- auto_pad="NOTSET",
- kernel_shape=[3, 3],
- pads=[1, 1, 1, 1],
- strides=[1, 1],
- ),
- )
- # Pool2D with stride
- verify_unary(
- "MaxPool",
- [1, 1, 32, 32],
- dict(
- auto_pad="NOTSET",
- kernel_shape=[3, 3],
- pads=[1, 1, 1, 1],
- strides=[2, 2],
- ),
- )
- # Pool2D with stride and autopadding
- verify_unary(
- "MaxPool",
- [1, 1, 32, 32],
- dict(
- auto_pad="SAME_UPPER",
- kernel_shape=[3, 7],
- pads=None,
- strides=[3, 2],
- ),
- )
- verify_unary(
- "MaxPool",
- [1, 1, 32, 32],
- dict(
- auto_pad="SAME_LOWER",
- kernel_shape=[3, 3],
- pads=None,
- strides=[2, 2],
- ),
- )
- verify_unary(
- "MaxPool",
- [1, 1, 32, 32],
- dict(
- auto_pad="VALID",
- kernel_shape=[3, 3],
- pads=None,
- strides=[2, 2],
- ),
- )
- verify_unary(
- "MaxPool",
- [1, 1, 32, 32],
- dict(
- auto_pad="SAME_UPPER",
- kernel_shape=[3, 3],
- pads=None,
- ),
- )
+def test_maxpool_and_averagepool():
+ for pool_name in ["MaxPool", "AveragePool"]:
+ # Pool1D
+ verify_unary(
+ pool_name,
+ [1, 1, 32],
+ dict(
+ auto_pad="NOTSET",
+ kernel_shape=[3],
+ pads=[1, 1],
+ strides=[1],
+ ),
+ )
+ # Pool1D with stride
+ verify_unary(
+ pool_name,
+ [1, 1, 32],
+ dict(
+ auto_pad="NOTSET",
+ kernel_shape=[3],
+ pads=[1, 2],
+ strides=[2],
+ ),
+ )
+ # Pool1D with stride and autopadding
+ verify_unary(
+ pool_name,
+ [1, 1, 32],
+ dict(
+ auto_pad="SAME_UPPER",
+ kernel_shape=[7],
+ pads=None,
+ strides=[2],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32],
+ dict(
+ auto_pad="SAME_LOWER",
+ kernel_shape=[4],
+ pads=None,
+ strides=[4],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32],
+ dict(
+ auto_pad="VALID",
+ kernel_shape=[5],
+ pads=None,
+ strides=[5],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32],
+ dict(
+ auto_pad="SAME_UPPER",
+ kernel_shape=[3],
+ pads=None,
+ ),
+ )
+ # Pool2D
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32],
+ dict(
+ auto_pad="NOTSET",
+ kernel_shape=[3, 3],
+ pads=[1, 1, 1, 1],
+ strides=[1, 1],
+ ),
+ )
+ # Pool2D with stride
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32],
+ dict(
+ auto_pad="NOTSET",
+ kernel_shape=[3, 3],
+ pads=[1, 1, 1, 1],
+ strides=[2, 2],
+ ),
+ )
+ # Pool2D with stride and autopadding
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32],
+ dict(
+ auto_pad="SAME_UPPER",
+ kernel_shape=[3, 7],
+ pads=None,
+ strides=[3, 2],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32],
+ dict(
+ auto_pad="SAME_LOWER",
+ kernel_shape=[3, 3],
+ pads=None,
+ strides=[2, 2],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32],
+ dict(
+ auto_pad="VALID",
+ kernel_shape=[3, 3],
+ pads=None,
+ strides=[2, 2],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32],
+ dict(
+ auto_pad="SAME_UPPER",
+ kernel_shape=[3, 3],
+ pads=None,
+ ),
+ )
+ # Pool3D
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32, 32],
+ dict(
+ auto_pad="NOTSET",
+ kernel_shape=[3, 3, 4],
+ pads=[1, 2, 1, 1, 2, 2],
+ strides=[1, 1, 1],
+ ),
+ )
+ # Pool3D with stride
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32, 32],
+ dict(
+ auto_pad="NOTSET",
+ kernel_shape=[3, 4, 3],
+ pads=[1, 1, 1, 1, 1, 2],
+ strides=[2, 2, 3],
+ ),
+ )
+ # Pool3D with stride and autopadding
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32, 32],
+ dict(
+ auto_pad="SAME_UPPER",
+ kernel_shape=[4, 3, 3],
+ pads=None,
+ strides=[3, 2, 2],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32, 32],
+ dict(
+ auto_pad="SAME_LOWER",
+ kernel_shape=[3, 3, 4],
+ pads=None,
+ strides=[2, 2, 2],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32, 32],
+ dict(
+ auto_pad="VALID",
+ kernel_shape=[3, 3, 5],
+ pads=None,
+ strides=[2, 2, 3],
+ ),
+ )
+ verify_unary(
+ pool_name,
+ [1, 1, 32, 32, 32],
+ dict(
+ auto_pad="SAME_UPPER",
+ kernel_shape=[3, 3, 5],
+ pads=None,
+ ),
+ )
def test_global_average_pool():
diff --git a/tests/python/relax/test_op_gradient_numeric.py
b/tests/python/relax/test_op_gradient_numeric.py
index bc5cb0f5be..acf0f615dd 100644
--- a/tests/python/relax/test_op_gradient_numeric.py
+++ b/tests/python/relax/test_op_gradient_numeric.py
@@ -802,11 +802,17 @@ def test_conv2d(target, dev, c2d_shape1, c2d_shape2,
c2d_kwargs):
),
(
(3, 3),
- {"strides": (2, 2), "padding": (1, 2), "dilation": (1, 1)},
+ {"strides": (2, 2), "padding": (1, 2), "dilation": (1, 1),
"count_include_pad": True},
),
(
(5, 5),
- {"strides": (2, 2), "padding": (2, 1), "dilation": (1, 1),
"ceil_mode": True},
+ {
+ "strides": (2, 2),
+ "padding": (2, 1),
+ "dilation": (1, 1),
+ "ceil_mode": True,
+ "count_include_pad": True,
+ },
),
)
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py
b/tests/python/relax/test_transform_legalize_ops_grad.py
index 19d1a106f8..f13748d2fa 100644
--- a/tests/python/relax/test_transform_legalize_ops_grad.py
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -282,8 +282,7 @@ def test_avg_pool2d_backward():
T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3])
with T.init():
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
- T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] =
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 <
T.int64(3), T.int64(0), T.Div((v_ax2 - T.int64(3)), T.int64(2)) + T.int64(1))
<= T.Div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh and T.Div((v_ax2 +
T.int64(2)), T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4),
T.int64(0), T.Div((v_ax3 - T.int64(4)), T.int64(2)) + T.int64(1)) <=
T.Div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww and T.Div((v_ax [...]
-
+ T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] =
T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 <
T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <=
T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Div(v_ax2 + T.int64(2),
T.int64(2)) - v_wh < T.int64(6) and T.Select(v_ax3 < T.int64(4), T.int64(0),
T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 +
T.int64(1), T.int64(2)) - v_ww and T.Div(v_ax3 + T.int64 [...]
@R.function
def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data:
R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10),
dtype="float32"):
cls = Expected
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 29171daaae..92d139d23b 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -743,7 +743,7 @@ def test_avg_pool2d():
T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
- pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0,
v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) -
v_ax1 * T.int64(2)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax2 *
T.int64(2)) + T.int64(2)))
+ pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0,
v_ax1, v_ax2, v_ax3] / T.Cast("float32", T.max((T.min(v_ax1 * T.int64(2) +
T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax1 * T.int64(2),
T.int64(0)) - v_ax1 * T.int64(2)) * (T.min(v_ax2 * T.int64(2) + T.int64(1),
T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax2 * T.int64(2), T.int64(0))
- v_ax2 * T.int64(2)), T.int64(1)))
@R.function
def main(x: R.Tensor((4, 112, 112, 6), dtype="float32")) ->
R.Tensor((4, 56, 56, 6), dtype="float32"):
@@ -785,8 +785,7 @@ def test_avg_pool2d_NCHW16c():
T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
- pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] =
pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32",
(T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1)) * (T.min(T.int64(2),
T.int64(111) - v_ax3) + T.int64(1)))
-
+ pool_avg[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] =
pool_sum[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] / T.Cast("float32",
T.max((T.min(T.int64(2), T.int64(111) - v_ax2) + T.int64(1) - T.max(T.int64(0)
- v_ax2, T.int64(0))) * (T.min(T.int64(2), T.int64(111) - v_ax3) + T.int64(1) -
T.max(T.int64(0) - v_ax3, T.int64(0))), T.int64(1)))
@R.function
def main(x: R.Tensor((4, 4, 112, 112, 16), dtype="float32")) ->
R.Tensor((4, 4, 110, 110, 16), dtype="float32"):
gv = R.call_tir(Expected.avg_pool2d, (x,), out_sinfo=R.Tensor((4,
4, 110, 110, 16), dtype="float32"))
@@ -834,7 +833,7 @@ def test_avg_pool2d_ceil_mode():
T.reads(pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
T.block_attr({"schedule_rule": "meta_schedule.pool_avg"})
- pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0,
v_ax1, v_ax2, v_ax3] / T.Cast("float32", (T.min(T.int64(1), T.int64(112) -
v_ax2 * T.int64(3)) + T.int64(2)) * (T.min(T.int64(1), T.int64(112) - v_ax3 *
T.int64(3)) + T.int64(2)))
+ pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = pool_sum[v_ax0,
v_ax1, v_ax2, v_ax3] / T.Cast("float32", T.max((T.min(v_ax2 * T.int64(3) +
T.int64(1), T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax2 * T.int64(3),
T.int64(0)) - v_ax2 * T.int64(3)) * (T.min(v_ax3 * T.int64(3) + T.int64(1),
T.int64(111)) + T.int64(2) - T.max(T.int64(1) - v_ax3 * T.int64(3), T.int64(0))
- v_ax3 * T.int64(3)), T.int64(1)))
@R.function
def main(x: R.Tensor((4, 6, 112, 112), dtype="float32")) ->
R.Tensor((4, 6, 38, 38), dtype="float32"):