This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d11bdcd  [Op] Do not override specified layout in pooling (2nd PR) (#9328)
d11bdcd is described below

commit d11bdcd3ad0717b8e38ba769e849d6a6afe6415e
Author: Joe Chou <54378300+ccjoec...@users.noreply.github.com>
AuthorDate: Thu Oct 21 14:53:33 2021 -0700

    [Op] Do not override specified layout in pooling (2nd PR) (#9328)
    
    * [Op] Do not override specified layout in pooling (2nd PR)
    
    * [Op] Do not override specified layout in pooling (2nd PR)
    
    * [Op] Do not override specified layout in pooling (2nd PR)
    
    * [Op] Do not override specified layout in pooling (2nd PR)
---
 include/tvm/relay/attrs/nn.h                       |  78 ++++++
 python/tvm/relay/op/nn/_nn.py                      | 110 +++++++-
 python/tvm/relay/op/nn/nn.py                       | 180 ++++++++++---
 src/relay/op/nn/pooling.cc                         |  75 ++++--
 src/relay/op/nn/pooling.h                          |   6 +-
 src/relay/qnn/op/convolution.cc                    |   4 +
 src/relay/transforms/pattern_utils.h               |   7 +-
 .../contrib/test_arm_compute_lib/test_pooling.py   |   2 +
 tests/python/relay/test_pass_convert_op_layout.py  | 288 +++++++++++++++++++++
 9 files changed, 675 insertions(+), 75 deletions(-)
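
For readers skimming the patch, a minimal sketch of the behavior this change enables (not part of the commit; it only reuses the Relay API exercised by the new test added below): ConvertLayout now honors a layout explicitly specified on a pooling op instead of overriding it with the inferred layout.

    import tvm
    from tvm import relay

    # Small NHWC graph where the pooling layout is given explicitly.
    x = relay.var("x", shape=(1, 56, 56, 64))
    w = relay.var("w", shape=(3, 3, 64, 64))
    y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1),
                        data_layout="NHWC", kernel_layout="HWIO")
    y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC")
    mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))

    # Request NCHW for conv2d but ask to keep max_pool2d in NHWC; with this patch
    # the pooling op is emitted with layout="NHWC", out_layout="NHWC" rather than
    # being rewritten to the layout inferred from its producer.
    desired_layouts = {"nn.conv2d": ["NCHW", "OIHW"], "nn.max_pool2d": ["NHWC"]}
    with tvm.transform.PassContext(opt_level=3):
        mod = relay.transform.ConvertLayout(desired_layouts)(mod)
    print(mod)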

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index de60deb..26d2c72 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -686,6 +686,7 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
   Array<IndexExpr> padding;
   Array<IndexExpr> dilation;
   tvm::String layout;
+  tvm::String out_layout;
   bool ceil_mode;
 
   TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
@@ -709,6 +710,13 @@ struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
         "dimensions respectively. Pooling is applied on the 'H' and"
         "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Pooling is applied on the 'H' and"
+            "'W' dimensions.");
     TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
         "When true, will use ceil instead of floor to compute the output 
shape.");
   }
@@ -721,6 +729,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
   Array<IndexExpr> padding;
   Array<IndexExpr> dilation;
   tvm::String layout;
+  tvm::String out_layout;
   bool ceil_mode;
   bool count_include_pad;
 
@@ -745,6 +754,13 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
         "dimensions respectively. Pooling is applied on the 'H' and"
         "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Pooling is applied on the 'H' and"
+            "'W' dimensions.");
     TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
         "When true, will use ceil instead of floor to compute the output 
shape.");
     TVM_ATTR_FIELD(count_include_pad)
@@ -756,6 +772,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
 /*! \brief Attributes for global pool operator */
 struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
   tvm::String layout;
+  tvm::String out_layout;
 
   TVM_DECLARE_ATTRS(GlobalPool2DAttrs, "relay.attrs.GlobalPool2DAttrs") {
     TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
@@ -763,6 +780,13 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
         "dimensions respectively. Pooling is applied on the 'H' and"
         "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Pooling is applied on the 'H' and"
+            "'W' dimensions.");
   }
 };
 
@@ -770,6 +794,7 @@ struct GlobalPool2DAttrs : public tvm::AttrsNode<GlobalPool2DAttrs> {
 struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
   Array<IndexExpr> output_size;
   std::string layout;
+  tvm::String out_layout;
 
   TVM_DECLARE_ATTRS(AdaptivePool1DAttrs, "relay.attrs.AdaptivePool1DAttrs") {
     TVM_ATTR_FIELD(output_size).set_default(Array<IndexExpr>({})).describe("Output width.");
@@ -778,6 +803,13 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
         "'N', 'C', 'W' stands for batch, channel, and width"
         "dimensions respectively. Pooling is applied on the"
         "'W' dimension.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
+            "'N', 'C', 'W' stands for batch, channel, and width"
+            "dimensions respectively. Pooling is applied on the"
+            "'W' dimension.");
   }
 };
 
@@ -785,6 +817,7 @@ struct AdaptivePool1DAttrs : public tvm::AttrsNode<AdaptivePool1DAttrs> {
 struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
   Array<IndexExpr> output_size;
   std::string layout;
+  tvm::String out_layout;
 
   TVM_DECLARE_ATTRS(AdaptivePool2DAttrs, "relay.attrs.AdaptivePool2DAttrs") {
     TVM_ATTR_FIELD(output_size)
@@ -795,6 +828,13 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
         "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
         "dimensions respectively. Pooling is applied on the 'H' and"
         "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc."
+            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+            "dimensions respectively. Pooling is applied on the 'H' and"
+            "'W' dimensions.");
   }
 };
 
@@ -802,6 +842,7 @@ struct AdaptivePool2DAttrs : public tvm::AttrsNode<AdaptivePool2DAttrs> {
 struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
   Array<IndexExpr> output_size;
   std::string layout;
+  tvm::String out_layout;
 
   TVM_DECLARE_ATTRS(AdaptivePool3DAttrs, "relay.attrs.AdaptivePool3DAttrs") {
     TVM_ATTR_FIELD(output_size)
@@ -812,6 +853,13 @@ struct AdaptivePool3DAttrs : public tvm::AttrsNode<AdaptivePool3DAttrs> {
         "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and 
width"
         "dimensions respectively. Pooling is applied on 'D', 'H' and"
         "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
+            "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, 
and width"
+            "dimensions respectively. Pooling is applied on 'D', 'H' and"
+            "'W' dimensions.");
   }
 };
 
@@ -822,6 +870,7 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
   Array<IndexExpr> dilation;
   Array<IndexExpr> padding;
   std::string layout;
+  tvm::String out_layout;
   bool ceil_mode;
 
   TVM_DECLARE_ATTRS(MaxPool1DAttrs, "relay.attrs.MaxPool1DAttrs") {
@@ -844,6 +893,12 @@ struct MaxPool1DAttrs : public tvm::AttrsNode<MaxPool1DAttrs> {
         "Dimension ordering of input data. Can be 'NCW', 'NWC', etc."
         "'N', 'C', 'W' stands for batch, channel, and width"
         "dimensions respectively. Pooling is applied on the 'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCW', 'NWC', etc."
+            "'N', 'C', 'W' stands for batch, channel, and width"
+            "dimensions respectively. Pooling is applied on the 'W' 
dimensions.");
     TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
         "When true, will use ceil instead of floor to compute the output 
shape.");
   }
@@ -856,6 +911,7 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
   Array<IndexExpr> dilation;
   Array<IndexExpr> padding;
   std::string layout;
+  tvm::String out_layout;
   bool ceil_mode;
   bool count_include_pad;
 
@@ -879,6 +935,12 @@ struct AvgPool1DAttrs : public tvm::AttrsNode<AvgPool1DAttrs> {
         "Dimension ordering of input data. Can be 'NCW', 'NHC', etc."
         "'N', 'C', 'W' stands for batch, channel, and width"
         "dimensions respectively. Pooling is applied on the 'W' dimension.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCW', 'NHC', etc."
+            "'N', 'C', 'W' stands for batch, channel, and width"
+            "dimensions respectively. Pooling is applied on the 'W' 
dimension.");
     TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
         "When true, will use ceil instead of floor to compute the output 
shape.");
     TVM_ATTR_FIELD(count_include_pad)
@@ -894,6 +956,7 @@ struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
   Array<IndexExpr> dilation;
   Array<IndexExpr> padding;
   std::string layout;
+  tvm::String out_layout;
   bool ceil_mode;
 
   TVM_DECLARE_ATTRS(MaxPool3DAttrs, "relay.attrs.MaxPool3DAttrs") {
@@ -917,6 +980,13 @@ struct MaxPool3DAttrs : public tvm::AttrsNode<MaxPool3DAttrs> {
         "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and 
width"
         "dimensions respectively. Pooling is applied on the 'D', 'H' and"
         "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
+            "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, 
and width"
+            "dimensions respectively. Pooling is applied on the 'D', 'H' and"
+            "'W' dimensions.");
     TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
         "When true, will use ceil instead of floor to compute the output 
shape.");
   }
@@ -929,6 +999,7 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
   Array<IndexExpr> dilation;
   Array<IndexExpr> padding;
   std::string layout;
+  tvm::String out_layout;
   bool ceil_mode;
   bool count_include_pad;
 
@@ -953,6 +1024,13 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
         "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and 
width"
         "dimensions respectively. Pooling is applied on the 'D', 'H' and"
         "'W' dimensions.");
+    TVM_ATTR_FIELD(out_layout)
+        .set_default("")
+        .describe(
+            "Dimension ordering of output data. Can be 'NCDHW', 'NDHWC', etc."
+            "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, 
and width"
+            "dimensions respectively. Pooling is applied on the 'D', 'H' and"
+            "'W' dimensions.");
     TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
         "When true, will use ceil instead of floor to compute the output 
shape.");
     TVM_ATTR_FIELD(count_include_pad)
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index f06ee09..17f75a0 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -18,7 +18,7 @@
 """Backend compiler related feature registration"""
 from __future__ import absolute_import
 
-from tvm import topi
+from tvm import topi, relay
 from tvm.topi.utils import get_const_tuple
 
 from tvm.runtime import convert
@@ -267,9 +267,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layouts):
     result : tvm.relay.Expr
         The transformed expr
     """
-    # pylint: disable=import-outside-toplevel
-    from tvm import relay
-
     data, weight = inputs
 
     # First check if there is a LayoutConfig scope, and if so, whether
@@ -363,9 +360,6 @@ def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):
     result : tvm.relay.Expr
         The transformed expr
     """
-    # pylint: disable=import-outside-toplevel
-    from tvm import relay
-
     data, weight = inputs
     new_attrs = dict(attrs)
     assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv2d's inputs"
@@ -446,9 +440,6 @@ def convert_conv3d(attrs, inputs, tinfos, desired_layouts):
     result : tvm.relay.Expr
         The transformed expr
     """
-    # pylint: disable=import-outside-toplevel
-    from tvm import relay
-
     data, weight = inputs
     new_attrs = dict(attrs)
     assert len(desired_layouts) == 2, "A desired layout is expected for both of nn.conv3d's inputs"
@@ -515,6 +506,30 @@ reg.register_schedule("nn.max_pool2d", strategy.schedule_pool)
 reg.register_pattern("nn.max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+@reg.register_convert_op_layout("nn.max_pool2d")
+def convert_max_pool2d(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for max_pool2d op.
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current pooling
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of one layout string
+        layout string defining our desired layout for input and output.
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    new_attrs = dict(attrs)
+    new_attrs["layout"] = str(desired_layouts[0])
+    new_attrs["out_layout"] = str(desired_layouts[0])
+    return relay.nn.max_pool2d(*inputs, **new_attrs)
+
+
 # max_pool3d
 reg.register_schedule("nn.max_pool3d", strategy.schedule_pool)
 reg.register_pattern("nn.max_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -530,6 +545,30 @@ reg.register_schedule("nn.avg_pool2d", strategy.schedule_pool)
 reg.register_pattern("nn.avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+@reg.register_convert_op_layout("nn.avg_pool2d")
+def convert_avg_pool2d(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for avg_pool2d op.
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current pooling
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of one layout string
+        layout string defining our desired layout for input and output.
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    new_attrs = dict(attrs)
+    new_attrs["layout"] = str(desired_layouts[0])
+    new_attrs["out_layout"] = str(desired_layouts[0])
+    return relay.nn.avg_pool2d(*inputs, **new_attrs)
+
+
 # avg_pool3d
 reg.register_schedule("nn.avg_pool3d", strategy.schedule_pool)
 reg.register_pattern("nn.avg_pool3d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -560,11 +599,59 @@ reg.register_schedule("nn.global_max_pool2d", strategy.schedule_adaptive_pool)
 reg.register_pattern("nn.global_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+@reg.register_convert_op_layout("nn.global_max_pool2d")
+def convert_global_max_pool2d(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for global_max_pool2d op.
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current pooling
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of one layout string
+        layout string defining our desired layout for input and output.
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    new_attrs = dict(attrs)
+    new_attrs["layout"] = str(desired_layouts[0])
+    new_attrs["out_layout"] = str(desired_layouts[0])
+    return relay.nn.global_max_pool2d(*inputs, **new_attrs)
+
+
 # global_avg_pool2d
 reg.register_schedule("nn.global_avg_pool2d", strategy.schedule_adaptive_pool)
 reg.register_pattern("nn.global_avg_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
 
 
+@reg.register_convert_op_layout("nn.global_avg_pool2d")
+def convert_global_avg_pool2d(attrs, inputs, tinfos, desired_layouts):
+    """Convert Layout pass registration for global_avg_pool2d op.
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current pooling
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    tinfos : list of types
+        List of input and output types
+    desired_layouts : list of one layout string
+        layout string defining our desired layout for input and output.
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The transformed expr
+    """
+    new_attrs = dict(attrs)
+    new_attrs["layout"] = str(desired_layouts[0])
+    new_attrs["out_layout"] = str(desired_layouts[0])
+    return relay.nn.global_avg_pool2d(*inputs, **new_attrs)
+
+
 # adaptive_max_pool2d
 reg.register_schedule("nn.adaptive_max_pool2d", 
strategy.schedule_adaptive_pool)
 reg.register_pattern("nn.adaptive_max_pool2d", OpPattern.OUT_ELEMWISE_FUSABLE)
@@ -796,9 +883,6 @@ def convert_deformable_conv2d(attrs, inputs, tinfos, desired_layouts):
     result : tvm.relay.Expr
         The transformed expr
     """
-    # pylint: disable=import-outside-toplevel
-    from tvm import relay
-
     data, offset, weight = inputs
     new_attrs = dict(attrs)
     for attr in new_attrs:
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 5a17db7..1821ff1 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -748,7 +748,14 @@ def log_softmax(data, axis=-1):
 
 
 def max_pool1d(
-    data, pool_size=(1,), strides=(1,), dilation=(1,), padding=(0,), layout="NCW", ceil_mode=False
+    data,
+    pool_size=(1,),
+    strides=(1,),
+    dilation=(1,),
+    padding=(0,),
+    layout="NCW",
+    out_layout="",
+    ceil_mode=False,
 ):
     r"""1D maximum pooling operator.
 
@@ -783,6 +790,9 @@ def max_pool1d(
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     ceil_mode : bool, optional
         To enable or disable ceil while pooling.
 
@@ -798,7 +808,9 @@ def max_pool1d(
     if isinstance(dilation, int):
         dilation = (dilation,)
     padding = get_pad_tuple1d(padding)
-    return _make.max_pool1d(data, pool_size, strides, dilation, padding, layout, ceil_mode)
+    return _make.max_pool1d(
+        data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode
+    )
 
 
 def max_pool2d(
@@ -808,6 +820,7 @@ def max_pool2d(
     dilation=(1, 1),
     padding=(0, 0),
     layout="NCHW",
+    out_layout="",
     ceil_mode=False,
 ):
     r"""2D maximum pooling operator.
@@ -851,6 +864,9 @@ def max_pool2d(
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     ceil_mode : bool, optional
         To enable or disable ceil while pooling.
 
@@ -866,7 +882,9 @@ def max_pool2d(
     if isinstance(dilation, int):
         dilation = (dilation, dilation)
     padding = get_pad_tuple2d(padding)
-    return _make.max_pool2d(data, pool_size, strides, dilation, padding, layout, ceil_mode)
+    return _make.max_pool2d(
+        data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode
+    )
 
 
 def max_pool3d(
@@ -876,6 +894,7 @@ def max_pool3d(
     dilation=(1, 1, 1),
     padding=(0, 0, 0),
     layout="NCDHW",
+    out_layout="",
     ceil_mode=False,
 ):
     r"""3D maximum pooling operator.
@@ -912,6 +931,9 @@ def max_pool3d(
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     ceil_mode : bool, optional
         To enable or disable ceil while pooling.
 
@@ -927,7 +949,9 @@ def max_pool3d(
     if isinstance(dilation, int):
         dilation = (dilation, dilation, dilation)
     padding = get_pad_tuple3d(padding)
-    return _make.max_pool3d(data, pool_size, strides, dilation, padding, layout, ceil_mode)
+    return _make.max_pool3d(
+        data, pool_size, strides, dilation, padding, layout, out_layout, ceil_mode
+    )
 
 
 def avg_pool1d(
@@ -937,6 +961,7 @@ def avg_pool1d(
     dilation=(1,),
     padding=(0,),
     layout="NCW",
+    out_layout="",
     ceil_mode=False,
     count_include_pad=False,
 ):
@@ -973,6 +998,9 @@ def avg_pool1d(
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     ceil_mode : bool, optional
         To enable or disable ceil while pooling.
 
@@ -992,7 +1020,15 @@ def avg_pool1d(
         dilation = (dilation,)
     padding = get_pad_tuple1d(padding)
     return _make.avg_pool1d(
-        data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad
+        data,
+        pool_size,
+        strides,
+        dilation,
+        padding,
+        layout,
+        out_layout,
+        ceil_mode,
+        count_include_pad,
     )
 
 
@@ -1003,6 +1039,7 @@ def avg_pool2d(
     dilation=(1, 1),
     padding=(0, 0),
     layout="NCHW",
+    out_layout="",
     ceil_mode=False,
     count_include_pad=False,
 ):
@@ -1048,6 +1085,9 @@ def avg_pool2d(
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     ceil_mode : bool, optional
         To enable or disable ceil while pooling.
 
@@ -1067,7 +1107,15 @@ def avg_pool2d(
         dilation = (dilation, dilation)
     padding = get_pad_tuple2d(padding)
     return _make.avg_pool2d(
-        data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad
+        data,
+        pool_size,
+        strides,
+        dilation,
+        padding,
+        layout,
+        out_layout,
+        ceil_mode,
+        count_include_pad,
     )
 
 
@@ -1078,6 +1126,7 @@ def avg_pool3d(
     dilation=(1, 1, 1),
     padding=(0, 0, 0),
     layout="NCDHW",
+    out_layout="",
     ceil_mode=False,
     count_include_pad=False,
 ):
@@ -1115,6 +1164,9 @@ def avg_pool3d(
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     ceil_mode : bool, optional
         To enable or disable ceil while pooling.
 
@@ -1134,7 +1186,15 @@ def avg_pool3d(
         dilation = (dilation, dilation, dilation)
     padding = get_pad_tuple3d(padding)
     return _make.avg_pool3d(
-        data, pool_size, strides, dilation, padding, layout, ceil_mode, count_include_pad
+        data,
+        pool_size,
+        strides,
+        dilation,
+        padding,
+        layout,
+        out_layout,
+        ceil_mode,
+        count_include_pad,
     )
 
 
@@ -1145,6 +1205,7 @@ def max_pool2d_grad(
     strides=(1, 1),
     padding=(0, 0),
     layout="NCHW",
+    out_layout="",
     ceil_mode=False,
 ):
     r"""Gradient of 2D maximum pooling operator.
@@ -1171,6 +1232,9 @@ def max_pool2d_grad(
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     ceil_mode : bool, optional
         To enable or disable ceil while pooling.
 
@@ -1179,7 +1243,9 @@ def max_pool2d_grad(
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.max_pool2d_grad(out_grad, data, pool_size, strides, padding, layout, ceil_mode)
+    return _make.max_pool2d_grad(
+        out_grad, data, pool_size, strides, padding, layout, out_layout, ceil_mode
+    )
 
 
 def avg_pool2d_grad(
@@ -1189,6 +1255,7 @@ def avg_pool2d_grad(
     strides=(1, 1),
     padding=(0, 0),
     layout="NCHW",
+    out_layout="",
     ceil_mode=False,
     count_include_pad=False,
 ):
@@ -1216,6 +1283,9 @@ def avg_pool2d_grad(
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     ceil_mode : bool, optional
         To enable or disable ceil while pooling.
 
@@ -1228,11 +1298,19 @@ def avg_pool2d_grad(
         The computed result.
     """
     return _make.avg_pool2d_grad(
-        out_grad, data, pool_size, strides, padding, layout, ceil_mode, count_include_pad
+        out_grad,
+        data,
+        pool_size,
+        strides,
+        padding,
+        layout,
+        out_layout,
+        ceil_mode,
+        count_include_pad,
     )
 
 
-def global_max_pool2d(data, layout="NCHW"):
+def global_max_pool2d(data, layout="NCHW", out_layout=""):
     r"""2D global maximum pooling operator.
 
     This operator takes data as input and does 2D max value calculation
@@ -1258,15 +1336,18 @@ def global_max_pool2d(data, layout="NCHW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.global_max_pool2d(data, layout)
+    return _make.global_max_pool2d(data, layout, out_layout)
 
 
-def global_avg_pool2d(data, layout="NCHW"):
+def global_avg_pool2d(data, layout="NCHW", out_layout=""):
     r"""2D global average pooling operator.
 
     This operator takes data as input and does 2D average value calculation
@@ -1292,12 +1373,15 @@ def global_avg_pool2d(data, layout="NCHW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : Optional[str]
+        Layout of the output
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
-    return _make.global_avg_pool2d(data, layout)
+    return _make.global_avg_pool2d(data, layout, out_layout)
 
 
 def upsampling(
@@ -3114,7 +3198,7 @@ def space_to_depth(data, block_size, layout="NCHW"):
     return _make.space_to_depth(data, block_size, layout)
 
 
-def adaptive_max_pool1d(data, output_size=None, layout="NCW"):
+def adaptive_max_pool1d(data, output_size=None, layout="NCW", out_layout=""):
     r"""1D adaptive max pooling operator. This operator is experimental.
 
     This operator takes data as input and does 1D max value calculation
@@ -3147,6 +3231,9 @@ def adaptive_max_pool1d(data, output_size=None, layout="NCW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
@@ -3155,10 +3242,10 @@ def adaptive_max_pool1d(data, output_size=None, layout="NCW"):
     output_size = [] or output_size
     if isinstance(output_size, int):
         output_size = [output_size]
-    return _make.adaptive_max_pool1d(data, output_size, layout)
+    return _make.adaptive_max_pool1d(data, output_size, layout, out_layout)
 
 
-def adaptive_avg_pool1d(data, output_size=None, layout="NCW"):
+def adaptive_avg_pool1d(data, output_size=None, layout="NCW", out_layout=""):
     r"""1D adaptive average pooling operator. This operator is experimental.
 
     This operator takes data as input and does 1D average value calculation
@@ -3191,6 +3278,9 @@ def adaptive_avg_pool1d(data, output_size=None, layout="NCW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
@@ -3199,10 +3289,10 @@ def adaptive_avg_pool1d(data, output_size=None, layout="NCW"):
     output_size = [] or output_size
     if isinstance(output_size, int):
         output_size = [output_size]
-    return _make.adaptive_avg_pool1d(data, output_size, layout)
+    return _make.adaptive_avg_pool1d(data, output_size, layout, out_layout)
 
 
-def adaptive_max_pool2d(data, output_size=None, layout="NCHW"):
+def adaptive_max_pool2d(data, output_size=None, layout="NCHW", out_layout=""):
     r"""2D adaptive max pooling operator. This operator is experimental.
 
     This operator takes data as input and does 2D max value calculation
@@ -3238,16 +3328,19 @@ def adaptive_max_pool2d(data, output_size=None, layout="NCHW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
     output_size = [] or output_size
-    return _make.adaptive_max_pool2d(data, output_size, layout)
+    return _make.adaptive_max_pool2d(data, output_size, layout, out_layout)
 
 
-def adaptive_avg_pool2d(data, output_size=None, layout="NCHW"):
+def adaptive_avg_pool2d(data, output_size=None, layout="NCHW", out_layout=""):
     r"""2D adaptive average pooling operator. This operator is experimental.
 
     This operator takes data as input and does 2D average value calculation
@@ -3283,16 +3376,19 @@ def adaptive_avg_pool2d(data, output_size=None, layout="NCHW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
     output_size = [] or output_size
-    return _make.adaptive_avg_pool2d(data, output_size, layout)
+    return _make.adaptive_avg_pool2d(data, output_size, layout, out_layout)
 
 
-def adaptive_max_pool3d(data, output_size=None, layout="NCDHW"):
+def adaptive_max_pool3d(data, output_size=None, layout="NCDHW", out_layout=""):
     r"""3D adaptive max pooling operator. This operator is experimental.
 
     This operator takes data as input and does 3D max value calculation
@@ -3327,16 +3423,19 @@ def adaptive_max_pool3d(data, output_size=None, layout="NCDHW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
     output_size = [] or output_size
-    return _make.adaptive_max_pool3d(data, output_size, layout)
+    return _make.adaptive_max_pool3d(data, output_size, layout, out_layout)
 
 
-def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"):
+def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW", out_layout=""):
     r"""3D adaptive avg pooling operator. This operator is experimental.
 
     This operator takes data as input and does 3D avg value calculation
@@ -3371,16 +3470,19 @@ def adaptive_avg_pool3d(data, output_size=None, layout="NCDHW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
     output_size = [] or output_size
-    return _make.adaptive_avg_pool3d(data, output_size, layout)
+    return _make.adaptive_avg_pool3d(data, output_size, layout, out_layout)
 
 
-def global_max_pool1d(data, layout="NCW"):
+def global_max_pool1d(data, layout="NCW", out_layout=""):
     r"""1D global maximum pooling operator.
 
     This operator takes data as input and does 1D max value calculation
@@ -3403,16 +3505,19 @@ def global_max_pool1d(data, layout="NCW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
     output_size = [1]
-    return _make.adaptive_max_pool1d(data, output_size, layout)
+    return _make.adaptive_max_pool1d(data, output_size, layout, out_layout)
 
 
-def global_avg_pool1d(data, layout="NCW"):
+def global_avg_pool1d(data, layout="NCW", out_layout=""):
     r"""1D global average pooling operator.
 
     This operator takes data as input and does 1D average value calculation
@@ -3436,16 +3541,19 @@ def global_avg_pool1d(data, layout="NCW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
     output_size = [1]
-    return _make.adaptive_avg_pool1d(data, output_size, layout)
+    return _make.adaptive_avg_pool1d(data, output_size, layout, out_layout)
 
 
-def global_max_pool3d(data, layout="NCDHW"):
+def global_max_pool3d(data, layout="NCDHW", out_layout=""):
     r"""3D global maximum pooling operator.
 
     This operator takes data as input and does 3D max value calculation
@@ -3469,16 +3577,19 @@ def global_max_pool3d(data, layout="NCDHW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
     output_size = [1, 1, 1]
-    return _make.adaptive_max_pool3d(data, output_size, layout)
+    return _make.adaptive_max_pool3d(data, output_size, layout, out_layout)
 
 
-def global_avg_pool3d(data, layout="NCDHW"):
+def global_avg_pool3d(data, layout="NCDHW", out_layout=""):
     r"""3D global average pooling operator.
 
     This operator takes data as input and does 3D average value calculation
@@ -3503,13 +3614,16 @@ def global_avg_pool3d(data, layout="NCDHW"):
     layout : str, optional
         Layout of the input.
 
+    out_layout : str, optional
+        Layout of the output.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
     output_size = [1, 1, 1]
-    return _make.adaptive_avg_pool3d(data, output_size, layout)
+    return _make.adaptive_avg_pool3d(data, output_size, layout, out_layout)
 
 
 def correlation(
diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc
index 0d40caa..cf44b30 100644
--- a/src/relay/op/nn/pooling.cc
+++ b/src/relay/op/nn/pooling.cc
@@ -49,8 +49,13 @@ InferCorrectLayoutOutput PoolInferCorrectLayout(const Attrs& attrs,
   ICHECK(attrs_ptr);
   ObjectPtr<T> params = make_object<T>(*attrs_ptr);
 
-  if (new_in_layouts.defined()) {
-    // Set the pool with the new layout.
+  if (params->out_layout != "") {
+    // when users specify the out_layout of pooling, follow user's preference
+    ICHECK_EQ(params->layout, params->out_layout)
+        << "Pooling input/output layouts mismatch: " << params->layout << " 
vs. "
+        << params->out_layout;
+  } else if (new_in_layouts.defined()) {
+    // the pooling is using an inferred layout (i.e., new_in_layouts[0]) given by relay caller
     ICHECK_EQ(new_in_layouts.size(), 1);
     params->layout = new_in_layouts[0].name();
   }
@@ -144,6 +149,7 @@ Array<te::Tensor> Pool2DCompute(const Attrs& attrs, const Array<te::Tensor>& inp
   auto padding = param->padding;
   auto ceil_mode = param->ceil_mode;
   Layout layout(param->layout);
+  Layout out_layout(param->out_layout);
 
   ICHECK(tir::BijectiveLayout(layout, kNCHW).defined())
       << "max_pool2d currently only supports layouts that are convertible from 
NCHW";
@@ -178,9 +184,9 @@ Array<te::Tensor> Pool2DCompute(const Attrs& attrs, const Array<te::Tensor>& inp
 TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool2d")
     .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                        Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
-                       bool ceil_mode) {
+                       String out_layout, bool ceil_mode) {
       return MakeMaxPool<MaxPool2DAttrs>(data, pool_size, strides, dilation, padding, layout,
-                                         ceil_mode, "nn.max_pool2d");
+                                         out_layout, ceil_mode, "nn.max_pool2d");
     });
 
 RELAY_REGISTER_OP("nn.max_pool2d")
@@ -216,9 +222,9 @@ RELAY_REGISTER_OP("nn.max_pool2d")
 TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool2d")
     .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                        Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
-                       bool ceil_mode, bool count_include_pad) {
+                       String out_layout, bool ceil_mode, bool count_include_pad) {
       return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, dilation, padding, layout,
-                                         ceil_mode, count_include_pad, "nn.avg_pool2d");
+                                         out_layout, ceil_mode, count_include_pad, "nn.avg_pool2d");
     });
 
 RELAY_REGISTER_OP("nn.avg_pool2d")
@@ -303,9 +309,10 @@ Array<te::Tensor> GlobalPool2DCompute(const Attrs& attrs, const Array<te::Tensor
   return Array<te::Tensor>{topi::nn::global_pool(inputs[0], mode, layout.name())};
 }
 
-Expr MakeGlobalAvgPool2D(Expr data, String layout) {
+Expr MakeGlobalAvgPool2D(Expr data, String layout, String out_layout) {
   auto attrs = make_object<GlobalPool2DAttrs>();
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   static const Op& op = Op::Get("nn.global_avg_pool2d");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -331,9 +338,10 @@ RELAY_REGISTER_OP("nn.global_avg_pool2d")
     .set_attr<FTVMCompute>("FTVMCompute", GlobalPool2DCompute<topi::nn::kAvgPool>);
 
 // GlobalMaxPool
-Expr MakeGlobalMaxPool2D(Expr data, String layout) {
+Expr MakeGlobalMaxPool2D(Expr data, String layout, String out_layout) {
   auto attrs = make_object<GlobalPool2DAttrs>();
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   static const Op& op = Op::Get("nn.global_max_pool2d");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -423,10 +431,12 @@ Array<te::Tensor> AdaptivePool1DCompute(const Attrs& attrs, const Array<te::Tens
 }
 
 // relay.nn.adaptive_avg_pool1d
-Expr MakeAdaptiveAvgPool1D(Expr data, Array<IndexExpr> output_size, String layout) {
+Expr MakeAdaptiveAvgPool1D(Expr data, Array<IndexExpr> output_size, String layout,
+                           String out_layout) {
   auto attrs = make_object<AdaptivePool1DAttrs>();
   attrs->output_size = std::move(output_size);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   static const Op& op = Op::Get("nn.adaptive_avg_pool1d");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -456,10 +466,12 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool1d")
     .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool1DCompute<topi::nn::kAvgPool>);
 
 // relay.nn.adaptive_max_pool1d
-Expr MakeAdaptiveMaxPool1D(Expr data, Array<IndexExpr> output_size, String layout) {
+Expr MakeAdaptiveMaxPool1D(Expr data, Array<IndexExpr> output_size, String layout,
+                           String out_layout) {
   auto attrs = make_object<AdaptivePool1DAttrs>();
   attrs->output_size = std::move(output_size);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   static const Op& op = Op::Get("nn.adaptive_max_pool1d");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -571,10 +583,12 @@ Array<te::Tensor> AdaptivePool2DCompute(const Attrs& attrs, const Array<te::Tens
 }
 
 // relay.nn.adaptive_avg_pool2d
-Expr MakeAdaptiveAvgPool2D(Expr data, Array<IndexExpr> output_size, String layout) {
+Expr MakeAdaptiveAvgPool2D(Expr data, Array<IndexExpr> output_size, String layout,
+                           String out_layout) {
   auto attrs = make_object<AdaptivePool2DAttrs>();
   attrs->output_size = std::move(output_size);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   static const Op& op = Op::Get("nn.adaptive_avg_pool2d");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -606,10 +620,12 @@ RELAY_REGISTER_OP("nn.adaptive_avg_pool2d")
     .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool2DCompute<topi::nn::kAvgPool>);
 
 // relay.nn.adaptive_max_pool2d
-Expr MakeAdaptiveMaxPool2D(Expr data, Array<IndexExpr> output_size, String layout) {
+Expr MakeAdaptiveMaxPool2D(Expr data, Array<IndexExpr> output_size, String layout,
+                           String out_layout) {
   auto attrs = make_object<AdaptivePool2DAttrs>();
   attrs->output_size = std::move(output_size);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   static const Op& op = Op::Get("nn.adaptive_max_pool2d");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -700,6 +716,7 @@ Array<te::Tensor> AdaptivePool3DCompute(const Attrs& attrs, const Array<te::Tens
   const auto* param = attrs.as<AdaptivePool3DAttrs>();
   ICHECK(param != nullptr);
   Layout layout(param->layout);
+  Layout out_layout(param->out_layout);
   ICHECK(tir::BijectiveLayout(layout, kNCDHW).defined())
       << "Adaptive pool3d currently only supports layouts that are convertible 
from NCDHW";
   ICHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1)
@@ -737,10 +754,12 @@ Array<te::Tensor> AdaptivePool3DCompute(const Attrs& attrs, const Array<te::Tens
 }
 
 // relay.nn.adaptive_max_pool3d
-Expr MakeAdaptiveMaxPool3D(Expr data, Array<IndexExpr> output_size, String layout) {
+Expr MakeAdaptiveMaxPool3D(Expr data, Array<IndexExpr> output_size, String layout,
+                           String out_layout) {
   auto attrs = make_object<AdaptivePool3DAttrs>();
   attrs->output_size = std::move(output_size);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   static const Op& op = Op::Get("nn.adaptive_max_pool3d");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -772,10 +791,12 @@ RELAY_REGISTER_OP("nn.adaptive_max_pool3d")
     .set_attr<FTVMCompute>("FTVMCompute", AdaptivePool3DCompute<topi::nn::kMaxPool>);
 
 // relay.nn.adaptive_max_pool3d
-Expr MakeAdaptiveAvgPool3D(Expr data, Array<IndexExpr> output_size, String layout) {
+Expr MakeAdaptiveAvgPool3D(Expr data, Array<IndexExpr> output_size, String layout,
+                           String out_layout) {
   auto attrs = make_object<AdaptivePool3DAttrs>();
   attrs->output_size = std::move(output_size);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   static const Op& op = Op::Get("nn.adaptive_avg_pool3d");
   return Call(op, {data}, Attrs(attrs), {});
 }
@@ -866,12 +887,13 @@ Array<te::Tensor> Pool2DGradCompute(const Attrs& attrs, const Array<te::Tensor>&
 // MaxPool2DGrad
 Expr MakeMaxPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
                        Array<IndexExpr> strides, Array<IndexExpr> padding, String layout,
-                       bool ceil_mode) {
+                       String out_layout, bool ceil_mode) {
   auto attrs = make_object<MaxPool2DAttrs>();
   attrs->pool_size = std::move(pool_size);
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   attrs->ceil_mode = ceil_mode;
   static const Op& op = Op::Get("nn.max_pool2d_grad");
   return Call(op, {out_grad, data}, Attrs(attrs), {});
@@ -913,12 +935,13 @@ RELAY_REGISTER_OP("nn.max_pool2d_grad")
 // AvgPool2DGrad
 Expr MakeAvgPool2DGrad(Expr out_grad, Expr data, Array<IndexExpr> pool_size,
                        Array<IndexExpr> strides, Array<IndexExpr> padding, String layout,
-                       bool ceil_mode, bool count_include_pad) {
+                       String out_layout, bool ceil_mode, bool count_include_pad) {
   auto attrs = make_object<AvgPool2DAttrs>();
   attrs->pool_size = std::move(pool_size);
   attrs->strides = std::move(strides);
   attrs->padding = std::move(padding);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   attrs->ceil_mode = ceil_mode;
   attrs->count_include_pad = count_include_pad;
   static const Op& op = Op::Get("nn.avg_pool2d_grad");
@@ -976,6 +999,7 @@ bool Pool1DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ICHECK(param != nullptr);
 
   Layout layout(param->layout);
+  Layout out_layout(param->out_layout);
   ICHECK(layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('w')))
       << "Invalid layout " << layout << ". Pool1D layout must have W, which cannot be split";
 
@@ -1018,6 +1042,7 @@ Array<te::Tensor> Pool1DCompute(const Attrs& attrs, const Array<te::Tensor>& inp
   auto padding = param->padding;
   auto ceil_mode = param->ceil_mode;
   Layout layout(param->layout);
+  Layout out_layout(param->out_layout);
 
   ICHECK(tir::BijectiveLayout(layout, kNCW).defined())
       << "max_pool1d currently only supports layouts that are convertible from 
NCW";
@@ -1046,9 +1071,9 @@ Array<te::Tensor> Pool1DCompute(const Attrs& attrs, const Array<te::Tensor>& inp
 TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool1d")
     .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                        Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
-                       bool ceil_mode) {
+                       String out_layout, bool ceil_mode) {
       return MakeMaxPool<MaxPool1DAttrs>(data, pool_size, strides, dilation, padding, layout,
-                                         ceil_mode, "nn.max_pool1d");
+                                         out_layout, ceil_mode, "nn.max_pool1d");
     });
 
 RELAY_REGISTER_OP("nn.max_pool1d")
@@ -1082,9 +1107,9 @@ RELAY_REGISTER_OP("nn.max_pool1d")
 TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool1d")
     .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                        Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
-                       bool ceil_mode, bool count_include_pad) {
+                       String out_layout, bool ceil_mode, bool count_include_pad) {
       return MakeAvgPool<AvgPool1DAttrs>(data, pool_size, strides, dilation, padding, layout,
-                                         ceil_mode, count_include_pad, "nn.avg_pool1d");
+                                         out_layout, ceil_mode, count_include_pad, "nn.avg_pool1d");
     });
 
 RELAY_REGISTER_OP("nn.avg_pool1d")
@@ -1134,6 +1159,7 @@ bool Pool3DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
   ICHECK(param != nullptr);
 
   Layout layout(param->layout);
+  Layout out_layout(param->out_layout);
   ICHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) &&
          layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) &&
          !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w')))
@@ -1194,6 +1220,7 @@ Array<te::Tensor> Pool3DCompute(const Attrs& attrs, const Array<te::Tensor>& inp
   auto padding = param->padding;
   auto ceil_mode = param->ceil_mode;
   Layout layout(param->layout);
+  Layout out_layout(param->out_layout);
 
   ICHECK(tir::BijectiveLayout(layout, kNCDHW).defined())
       << "max_pool3d currently only supports layouts that are convertible from 
NCDHW";
@@ -1231,9 +1258,9 @@ Array<te::Tensor> Pool3DCompute(const Attrs& attrs, const Array<te::Tensor>& inp
 TVM_REGISTER_GLOBAL("relay.op.nn._make.max_pool3d")
     .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                        Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
-                       bool ceil_mode) {
+                       String out_layout, bool ceil_mode) {
       return MakeMaxPool<MaxPool3DAttrs>(data, pool_size, strides, dilation, padding, layout,
-                                         ceil_mode, "nn.max_pool3d");
+                                         out_layout, ceil_mode, "nn.max_pool3d");
     });
 
 RELAY_REGISTER_OP("nn.max_pool3d")
@@ -1270,9 +1297,9 @@ RELAY_REGISTER_OP("nn.max_pool3d")
 TVM_REGISTER_GLOBAL("relay.op.nn._make.avg_pool3d")
     .set_body_typed([](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                        Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
-                       bool ceil_mode, bool count_include_pad) {
+                       String out_layout, bool ceil_mode, bool count_include_pad) {
       return MakeAvgPool<AvgPool3DAttrs>(data, pool_size, strides, dilation, padding, layout,
-                                         ceil_mode, count_include_pad, "nn.avg_pool3d");
+                                         out_layout, ceil_mode, count_include_pad, "nn.avg_pool3d");
     });
 
 RELAY_REGISTER_OP("nn.avg_pool3d")
diff --git a/src/relay/op/nn/pooling.h b/src/relay/op/nn/pooling.h
index 9b7eab2..32ae464 100644
--- a/src/relay/op/nn/pooling.h
+++ b/src/relay/op/nn/pooling.h
@@ -35,13 +35,14 @@ namespace relay {
 template <typename T>
 inline Expr MakeMaxPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                         Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
-                        bool ceil_mode, String op_name) {
+                        String out_layout, bool ceil_mode, String op_name) {
   auto attrs = make_object<T>();
   attrs->pool_size = std::move(pool_size);
   attrs->strides = std::move(strides);
   attrs->dilation = std::move(dilation);
   attrs->padding = std::move(padding);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   attrs->ceil_mode = ceil_mode;
   static const Op& op = Op::Get(op_name);
   return Call(op, {data}, Attrs(attrs), {});
@@ -50,13 +51,14 @@ inline Expr MakeMaxPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr>
 template <typename T>
 inline Expr MakeAvgPool(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                         Array<IndexExpr> dilation, Array<IndexExpr> padding, String layout,
-                        bool ceil_mode, bool count_include_pad, String op_name) {
+                        String out_layout, bool ceil_mode, bool count_include_pad, String op_name) {
   auto attrs = make_object<T>();
   attrs->pool_size = std::move(pool_size);
   attrs->strides = std::move(strides);
   attrs->dilation = std::move(dilation);
   attrs->padding = std::move(padding);
   attrs->layout = std::move(layout);
+  attrs->out_layout = std::move(out_layout);
   attrs->ceil_mode = ceil_mode;
   attrs->count_include_pad = count_include_pad;
   static const Op& op = Op::Get(op_name);
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index 5782f1f..ecdd36d 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -275,6 +275,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_
     Array<IndexExpr> padding({0, 0});
     reduced_t2 = AvgPool2D(scaled_hw_t2, param->kernel_size, param->strides, param->dilation,
                            padding, param->data_layout,
+                           "",      // out_layout
                            false,   // ceil_mode
                            false);  // count_include_pad
   } else {
@@ -284,6 +285,7 @@ Expr DepthwiseConv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_
       Array<IndexExpr> padding({0, 0});
       reduced_t2 = AvgPool2D(reduced_t2, param->kernel_size, param->strides, param->dilation,
                              padding, param->data_layout,
+                             "",      // out_layout
                              false,   // ceil_mode
                              false);  // count_include_pad
     }
@@ -463,6 +465,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
         Multiply(reduced_c_t2, MakeConstantScalar(DataType::Int(32), kernel_h * kernel_w));
     reduced_t2 = AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, param->dilation,
                            padding, param->data_layout,
+                           "",      // out_layout
                            false,   // ceil_mode
                            false);  // count_include_pad
   } else {
@@ -471,6 +474,7 @@ Expr Conv2DSecondTerm(const Expr& padded_data, const Expr& kernel_zero_point,
     if (stride1 * stride2 != 1) {
       reduced_t2 = AvgPool2D(reduced_c_t2, param->kernel_size, param->strides, param->dilation,
                              padding, param->data_layout,
+                             "",      // out_layout
                              false,   // ceil_mode
                              false);  // count_include_pad
     }
diff --git a/src/relay/transforms/pattern_utils.h b/src/relay/transforms/pattern_utils.h
index 692ef3c..03b8ee6 100644
--- a/src/relay/transforms/pattern_utils.h
+++ b/src/relay/transforms/pattern_utils.h
@@ -676,9 +676,10 @@ static inline Expr Reshape(Expr data, Array<Integer> newshape) {
 
 static inline Expr AvgPool2D(Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides,
                              Array<IndexExpr> dilation, Array<IndexExpr> padding,
-                             std::string layout, bool ceil_mode, bool count_include_pad) {
-  return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, dilation, padding, layout, ceil_mode,
-                                     count_include_pad, "nn.avg_pool2d");
+                             std::string layout, std::string out_layout, bool ceil_mode,
+                             bool count_include_pad) {
+  return MakeAvgPool<AvgPool2DAttrs>(data, pool_size, strides, dilation, padding, layout,
+                                     out_layout, ceil_mode, count_include_pad, "nn.avg_pool2d");
 }
 
 static inline Expr Pad(Expr data, Array<Array<IndexExpr>> pad_width, Expr pad_value,
diff --git a/tests/python/contrib/test_arm_compute_lib/test_pooling.py b/tests/python/contrib/test_arm_compute_lib/test_pooling.py
index 9deaa75..b174f9a 100644
--- a/tests/python/contrib/test_arm_compute_lib/test_pooling.py
+++ b/tests/python/contrib/test_arm_compute_lib/test_pooling.py
@@ -123,6 +123,7 @@ def _get_expected_pooling_codegen(
             "num_inputs": "1",
             "num_outputs": "1",
             "layout": [["NHWC"]],
+            "out_layout": [[""]],
             "shape": [[list(output_shape)]],
             "dtype": [[dtype]],
             "padding": [[str(p) for p in padding]],
@@ -149,6 +150,7 @@ def _get_expected_global_pooling_codegen(shape, dtype, typef):
             "num_inputs": "1",
             "num_outputs": "1",
             "layout": [["NHWC"]],
+            "out_layout": [[""]],
             "shape": [[[1, 1, 1, shape[3]]]],
             "dtype": [[dtype]],
         },
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 9b4d154..2359dcd 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -248,6 +248,61 @@ def test_conv_bias_pool_convert_layout():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_bias_pool_uses_specified_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        bias = relay.var("bias", shape=(64,))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.nn.bias_add(y, bias, axis=3)
+        # a useless tuple, which will be eliminated
+        y = relay.Tuple([y])[0]
+        y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC")
+        y = relay.cast(y, "int32")
+        y = relay.nn.batch_flatten(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        bias = relay.var("bias", shape=(64,))
+        weight = relay.var("weight", shape=(3, 3, 64, 64))
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        weight = relay.layout_transform(weight, "HWIO", "OIHW")
+        y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
+
+        bias = relay.expand_dims(bias, axis=0, num_newaxis=3)
+        bias = relay.layout_transform(bias, "NHWC", "NCHW")
+        y = relay.add(y, bias)
+        # a useless tuple, which will be eliminated
+        y = relay.Tuple([y])[0]
+        y = relay.nn.relu(y)
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC", out_layout="NHWC")
+        y = relay.cast(y, "int32")
+        y = relay.nn.batch_flatten(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(
+        a,
+        transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"], 
"nn.max_pool2d": ["NHWC"]}),
+    )
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
+
+
 def test_conv_concat_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
@@ -412,6 +467,139 @@ def test_deformable_conv_bias_pool_convert_layout():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_deformable_conv_bias_pool_uses_specified_convert_layout():
+    def before(N, CI, H, W, CO, KH, KW, layout):
+        if layout == "NCHW":
+            data_shape = (N, CI, H, W)
+            weight_shape = (CO, CI, KH, KW)
+            kernel_layout = "OIHW"
+        else:
+            data_shape = (N, H, W, CI)
+            weight_shape = (KH, KW, CI, CO)
+            kernel_layout = "HWIO"
+        bias_shape = (CO,)
+
+        data = relay.var("data", shape=data_shape, dtype="float32")
+        offset = relay.var("offset")
+        weight = relay.var("weight", shape=weight_shape, dtype="float32")
+        bias = relay.var("bias", shape=bias_shape, dtype="float32")
+
+        y = relay.nn.deformable_conv2d(
+            data,
+            offset,
+            weight,
+            kernel_size=(KH, KW),
+            channels=CO,
+            data_layout=layout,
+            kernel_layout=kernel_layout,
+        )
+        y = relay.nn.bias_add(y, bias, axis=-1 if layout == "NHWC" else 1)
+        y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout=layout)
+        y = relay.cast(y, "int32")
+        y = relay.nn.batch_flatten(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected(N, CI, H, W, CO, KH, KW, OH, OW, src_layout, dst_layout, max_pool_layout=None):
+        layout_map = {"src": {}, "dst": {}}
+        if src_layout == "NCHW":
+            nchw = layout_map["src"]
+            nhwc = layout_map["dst"]
+        else:
+            nchw = layout_map["dst"]
+            nhwc = layout_map["src"]
+
+        nchw["data_layout"] = "NCHW"
+        nchw["data_shape"] = (N, CI, H, W)
+        nchw["offset_shape"] = (N, KH * KW * 2, OH, OW)
+        nchw["weight_shape"] = (CO, CI, KH, KW)
+        nchw["kernel_layout"] = "OIHW"
+
+        nhwc["data_layout"] = "NHWC"
+        nhwc["data_shape"] = (N, H, W, CI)
+        nhwc["offset_shape"] = (N, OH, OW, KH * KW * 2)
+        nhwc["weight_shape"] = (KH, KW, CI, CO)
+        nhwc["kernel_layout"] = "HWIO"
+
+        bias_shape = (CO,)
+
+        data = relay.var("data", shape=layout_map["src"]["data_shape"], 
dtype="float32")
+        offset = relay.var("offset", shape=layout_map["src"]["offset_shape"], 
dtype="float32")
+        weight = relay.var("weight", shape=layout_map["src"]["weight_shape"], 
dtype="float32")
+        bias = relay.var("bias", shape=bias_shape, dtype="float32")
+
+        data = relay.layout_transform(
+            data, layout_map["src"]["data_layout"], 
layout_map["dst"]["data_layout"]
+        )
+        offset = relay.layout_transform(
+            offset, layout_map["src"]["data_layout"], 
layout_map["dst"]["data_layout"]
+        )
+        weight = relay.layout_transform(
+            weight, layout_map["src"]["kernel_layout"], 
layout_map["dst"]["kernel_layout"]
+        )
+        y = relay.nn.deformable_conv2d(
+            data,
+            offset,
+            weight,
+            kernel_size=(KH, KW),
+            channels=CO,
+            data_layout=layout_map["dst"]["data_layout"],
+            kernel_layout=layout_map["dst"]["kernel_layout"],
+        )
+        if layout_map["src"]["data_layout"] == "NHWC":
+            bias = relay.expand_dims(bias, axis=0, num_newaxis=3)
+        else:
+            bias = relay.expand_dims(bias, axis=1, num_newaxis=2)
+            bias = relay.expand_dims(bias, axis=0)
+        bias = relay.layout_transform(
+            bias, layout_map["src"]["data_layout"], 
layout_map["dst"]["data_layout"]
+        )
+        y = relay.add(y, bias)
+        y = relay.nn.relu(y)
+        if max_pool_layout != layout_map["dst"]["data_layout"]:
+            y = relay.layout_transform(y, layout_map["dst"]["data_layout"], 
max_pool_layout)
+        y = relay.nn.max_pool2d(
+            y, pool_size=(2, 2), layout=max_pool_layout, out_layout=max_pool_layout
+        )
+        y = relay.cast(y, "int32")
+        y = relay.nn.batch_flatten(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    # NHWC -> NCHW
+    a = before(1, 3, 224, 224, 32, 3, 3, "NHWC")
+    a = run_opt_pass(
+        a,
+        transform.ConvertLayout(
+            {"nn.deformable_conv2d": ["NCHW", "default"], "nn.max_pool2d": 
["NHWC"]}
+        ),
+    )
+    # - in the before() func, its last argument "NHWC" is also the layout of max_pool
+    b = run_opt_pass(
+        # max_pool has its own layout argument
+        expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NHWC", "NCHW", max_pool_layout="NHWC"),
+        transform.InferType(),
+    )
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
+
+    # NCHW -> NHWC
+    a = before(1, 3, 224, 224, 32, 3, 3, "NCHW")
+    a = run_opt_pass(
+        a,
+        transform.ConvertLayout(
+            {"nn.deformable_conv2d": ["NHWC", "default"], "nn.max_pool2d": 
["NCHW"]}
+        ),
+    )
+    # - in the before() func, its last argument "NCHW" is also the layout of max_pool
+    b = run_opt_pass(
+        # max_pool has its own layout argument
+        expected(1, 3, 224, 224, 32, 3, 3, 222, 222, "NCHW", "NHWC", max_pool_layout="NCHW"),
+        transform.InferType(),
+    )
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
+
+
 def test_dual_path_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
@@ -702,6 +890,57 @@ def test_resnet_convert_layout():
     assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_resnet_pool_uses_specified_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        weight2 = relay.var("weight2", shape=(1, 1, 64, 32))
+        y = relay.nn.conv2d(
+            x,
+            weight1,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.nn.relu(y)
+        y2 = relay.nn.conv2d(
+            x, weight2, channels=32, kernel_size=(1, 1), data_layout="NHWC", kernel_layout="HWIO"
+        )
+        y2 = relay.nn.relu(y2)
+        y = y + y2
+        y = relay.nn.global_max_pool2d(y, layout="NHWC")
+        return relay.Function(analysis.free_vars(y), y)
+
+    def expected():
+        x = relay.var("x", shape=(1, 56, 56, 64))
+        weight1 = relay.var("weight1", shape=(3, 3, 64, 32))
+        weight2 = relay.var("weight2", shape=(1, 1, 64, 32))
+        weight1 = relay.layout_transform(weight1, "HWIO", "OIHW")
+        weight2 = relay.layout_transform(weight2, "HWIO", "OIHW")
+        x = relay.layout_transform(x, "NHWC", "NCHW")
+        y = relay.nn.conv2d(x, weight1, channels=32, kernel_size=(3, 3), padding=(1, 1))
+        y = relay.nn.relu(y)
+        y2 = relay.nn.conv2d(x, weight2, channels=32, kernel_size=(1, 1))
+        y2 = relay.nn.relu(y2)
+        y = y + y2
+        y = relay.layout_transform(y, "NCHW", "NHWC")
+        y = relay.nn.global_max_pool2d(y, layout="NHWC", out_layout="NHWC")
+        return relay.Function(analysis.free_vars(y), y)
+
+    a = before()
+    a = run_opt_pass(
+        a,
+        transform.ConvertLayout(
+            {"nn.conv2d": ["NCHW", "default"], "nn.global_max_pool2d": 
["NHWC"]}
+        ),
+    )
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n Expected = \n" + str(b)
+
+
 def test_scalar_convert_layout():
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
@@ -2039,5 +2278,54 @@ def test_reduce_op_convert_layout():
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_max_pool_uses_specified_convert_layout():
+    def before():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NCHW",
+            kernel_layout="OIHW",
+        )
+        y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NCHW")
+        y = relay.nn.batch_flatten(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def expected():
+        x = relay.var("x", shape=(1, 64, 56, 56))
+        weight = relay.var("weight", shape=(64, 64, 3, 3))
+        x = relay.layout_transform(x, "NCHW", "NHWC")
+        weight = relay.layout_transform(weight, "OIHW", "OHWI")
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=64,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="OHWI",
+        )
+        y = relay.nn.relu(y)
+        y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout="NHWC", out_layout="NHWC")
+        y = relay.layout_transform(y, "NHWC", "NCHW")
+        y = relay.nn.batch_flatten(y)
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    a = before()
+    a = run_opt_pass(
+        a, transform.ConvertLayout({"nn.conv2d": ["NHWC", "OHWI"], 
"nn.max_pool2d": ["NHWC"]})
+    )
+    b = run_opt_pass(expected(), transform.InferType())
+
+    assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + "\n\n 
Expected = \n" + str(b)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

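Editor's note: as a quick reference, below is a minimal sketch (not part of this commit) of how the out_layout attribute exercised by the tests above can be passed from the Python API. It assumes a TVM build that includes this change; variable names are illustrative only.

import tvm
from tvm import relay

x = relay.var("x", shape=(1, 56, 56, 64))
# Specifying out_layout pins the pooling output layout so passes such as
# ConvertLayout do not override the layout the user asked for.
y = relay.nn.max_pool2d(x, pool_size=(2, 2), layout="NHWC", out_layout="NHWC")
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(y), y))
print(relay.transform.InferType()(mod))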