This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 0d8c9cef72 [Relay] Extend split for blocked ConvertLayout pass (#12886)
0d8c9cef72 is described below

commit 0d8c9cef7212e62c18814f1632613fb04de6d290
Author: Andrey Malyshev <elvin.n...@gmail.com>
AuthorDate: Thu Sep 29 16:50:59 2022 +0400

    [Relay] Extend split for blocked ConvertLayout pass (#12886)

    * [Relay] Extend split for blocked ConvertLayout pass

    * Fix lint hits

    * Fix spelling
---
 src/relay/op/tensor/transform.cc                  | 24 ++++++++++-
 tests/python/relay/test_pass_convert_op_layout.py | 49 +++++++++++++++++++++++
 2 files changed, 72 insertions(+), 1 deletion(-)

diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index deb05e8877..985222307a 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2982,10 +2982,32 @@ InferCorrectLayoutOutput SplitInferCorrectLayout(const Attrs& attrs,
 
   // If new_in_layouts are defined, this code tries to modify the layout.
   if (new_in_layouts.defined() && old_in_layouts.defined()) {
+    bool divisible = true;
     const auto& sp_dim = old_in_layouts[0][axis];
     auto new_index = new_in_layouts[0].IndexOf(sp_dim);
     param->axis = new_index;
-    ret = new_in_layouts[0];
+    int factor = new_in_layouts[0].FactorOf(sp_dim);
+    if (factor > 1) {
+      if (!param->indices_or_sections.as<IntImmNode>()) {
+        auto ios = Downcast<Array<Integer>>(param->indices_or_sections);
+        Array<Integer> new_ios;
+        for (const auto& v : ios) {
+          const IntImmNode* vint = v.as<IntImmNode>();
+          new_ios.push_back(vint->value / factor);
+          if (vint->value % factor) {
+            divisible = false;
+          }
+        }
+        if (divisible) {
+          param->indices_or_sections = new_ios;
+        }
+      }
+    }
+    if (divisible) {
+      ret = new_in_layouts[0];
+    } else {
+      ret = old_in_layouts[0];
+    }
   } else if (old_in_layouts.defined()) {
     ret = old_in_layouts[0];
   }
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index 3d5af83b8c..223926a877 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -1760,9 +1760,58 @@ def test_conv_split_convert_layout():
 
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
+    def _test_conv_split_convert_layout_blocking():
+        def before():
+            x = relay.var("x", shape=(1, 512, 38, 38))
+            weight = relay.var("weight", shape=(512, 512, 3, 3))
+            y = relay.nn.conv2d(
+                x,
+                weight,
+                channels=512,
+                kernel_size=(3, 3),
+                data_layout="NCHW",
+                kernel_layout="OIHW",
+            )
+            y = relay.nn.relu(y)
+            y = relay.op.split(y, indices_or_sections=[256], axis=1).astuple()
+            a = relay.TupleGetItem(y, 0)
+            b = relay.TupleGetItem(y, 1)
+            out = relay.Tuple([a, b])
+            return relay.Function(analysis.free_vars(out), out)
+
+        def expected():
+            x = relay.var("x", shape=(1, 512, 38, 38))
+            weight = relay.var("weight", shape=(512, 512, 3, 3))
+            weight = relay.layout_transform(weight, "OIHW", "OIHW4o")
+            x = relay.layout_transform(x, "NCHW", "NCHW4c")
+            y = relay.op.nn.contrib_conv2d_nchwc(
+                x,
+                weight,
+                channels=512,
+                kernel_size=(3, 3),
+                padding=(0, 0),
+                data_layout="NCHW4c",
+                kernel_layout="OIHW4o",
+            )
+            y = relay.nn.relu(y)
+            y = relay.op.split(y, indices_or_sections=[64], axis=1).astuple()
+            a = relay.TupleGetItem(y, 0)
+            b = relay.TupleGetItem(y, 1)
+            a = relay.layout_transform(a, "NCHW4c", "NCHW")
+            b = relay.layout_transform(b, "NCHW4c", "NCHW")
+            out = relay.Tuple([a, b])
+            return relay.Function(analysis.free_vars(out), out)
+
+        a = before()
+        a = run_opt_pass(a, transform.ConvertLayout({"nn.conv2d": ["NCHW4c", "OIHW4o"]}))
+        b = run_opt_pass(expected(), transform.InferType())
+
+        assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
+
     _test_conv_split_convert_layout1()
     _test_conv_split_convert_layout2()
     _test_conv_split_convert_layout3()
+    _test_conv_split_convert_layout_blocking()
 
 
 def test_conv_strided_slice_axes_convert_layout():