This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 0f4c065 [Relay][Convert Layout] Enable layout transformation for image.resize op (#8205) 0f4c065 is described below commit 0f4c0654ef94c2252d0075e726b2c6589430d9d7 Author: Jorn Tuyls <jtu...@users.noreply.github.com> AuthorDate: Fri Jun 18 08:56:49 2021 +0200 [Relay][Convert Layout] Enable layout transformation for image.resize op (#8205) * Enable layout transformation for image.resize op * Change str map function to str and index retrieval * Fix for pytorch frontend segmentation models test --- python/tvm/relay/op/image/_image.py | 31 ++++++++ src/relay/op/image/resize.cc | 26 +++++++ tests/python/relay/test_pass_convert_op_layout.py | 86 +++++++++++++++++++++++ 3 files changed, 143 insertions(+) diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py index 5b7fd32..2071a43 100644 --- a/python/tvm/relay/op/image/_image.py +++ b/python/tvm/relay/op/image/_image.py @@ -26,6 +26,7 @@ from tvm.topi.utils import get_const_tuple from .. import op as reg from .. import strategy from ..op import OpPattern +from .image import resize # resize @@ -58,6 +59,36 @@ def compute_resize(attrs, inputs, out_type): reg.register_injective_schedule("image.resize") +@reg.register_convert_op_layout("image.resize") +def convert_image_resize(attrs, inputs, tinfos, desired_layouts): + """Convert Layout pass registration for image resize op. + + Parameters + ---------- + attrs : tvm.ir.Attrs + Attributes of current resize op + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layouts : list of layout strings + List of layouts defining our desired + layout for the data input. + + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + + new_attrs = dict(attrs) + assert len(desired_layouts) == 1, "Only one desired layout is expected" + desired_layout = str(desired_layouts[0]) + assert desired_layout != "default", "Layout cannot be default" + new_attrs["layout"] = desired_layout + return resize(*inputs, **new_attrs) + + @script def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis): out = output_tensor((4,), "int64") diff --git a/src/relay/op/image/resize.cc b/src/relay/op/image/resize.cc index 9c3d601..2c90d7b 100644 --- a/src/relay/op/image/resize.cc +++ b/src/relay/op/image/resize.cc @@ -33,6 +33,31 @@ namespace relay { TVM_REGISTER_NODE_TYPE(ResizeAttrs); +template <typename T> +Array<Array<Layout> > ResizeInferCorrectLayout(const Attrs& attrs, + const Array<Layout>& new_in_layouts, + const Array<Layout>& old_in_layouts, + const Array<tvm::relay::Type>& old_in_types) { + // NOTE: Discard "const" qualifier here. + T* params = const_cast<T*>(attrs.as<T>()); + + if (new_in_layouts.defined()) { + ICHECK_EQ(new_in_layouts.size(), 1); + + Layout raw_layout(params->layout); + Layout new_layout = new_in_layouts[0]; + Layout old_layout = old_in_layouts[0]; + if (!new_layout.Equals(old_layout) && raw_layout.Equals(old_layout) && + new_layout->axes.size() == old_layout->axes.size()) { + // Follow input layout + params->layout = new_layout.name(); + } + } + + Layout inferred_layout(params->layout); + return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}}; +} + bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs, const TypeReporter& reporter) { ICHECK_EQ(types.size(), 2); @@ -102,6 +127,7 @@ RELAY_REGISTER_OP("image.resize") .add_argument("data", "Tensor", "The input tensor.") .set_support_level(5) .add_type_rel("Resize", ResizeRel) + .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ResizeInferCorrectLayout<ResizeAttrs>) .set_attr<TOpPattern>("TOpPattern", kInjective); TVM_REGISTER_NODE_TYPE(Resize3dAttrs); diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py index 4710d50..88590c9 100644 --- a/tests/python/relay/test_pass_convert_op_layout.py +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -1797,6 +1797,90 @@ def test_conv_reduce_convert_layout(): _test_conv_reduce_convert_layout2() +def test_image_resize_convert_layout(): + def _test_image_resize_convert_layout_nchw_to_nhwc(): + def before(): + x = relay.var("x", shape=(1, 2, 4, 4)) + y = relay.image.resize(x, (8, 8)) + y = relay.Function([x], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 2, 4, 4)) + x = relay.layout_transform(x, "NCHW", "NHWC") + y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.layout_transform(y, "NHWC", "NCHW") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NHWC"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def _test_image_resize_convert_layout_nhwc_to_nchw(): + def before(): + x = relay.var("x", shape=(1, 4, 4, 2)) + y = relay.image.resize(x, (8, 8), layout="NHWC") + y = relay.Function([x], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 4, 4, 2)) + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.image.resize(x, (8, 8), layout="NCHW") + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"image.resize": ["NCHW"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + _test_image_resize_convert_layout_nchw_to_nhwc() + _test_image_resize_convert_layout_nhwc_to_nchw() + + +def test_conv_image_resize_convert_layout(): + """Check that layout transforms are propagated through image resize.""" + + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d( + x, + weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.image.resize(y, (112, 112), layout="NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + w = relay.var("weight", shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, "NHWC", "NCHW") + w = relay.layout_transform(w, "HWIO", "OIHW") + y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), padding=(1, 1)) + y = relay.image.resize(y, (112, 112), layout="NCHW") + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + if __name__ == "__main__": test_qnn_binary_no_convert_layout() test_no_convert_layout() @@ -1828,3 +1912,5 @@ if __name__ == "__main__": test_conv_squeeze_convert_layout() test_conv_reduce_convert_layout() test_conv_strided_slice_axes_convert_layout() + test_image_resize_convert_layout() + test_conv_image_resize_convert_layout()