This is an automated email from the ASF dual-hosted git repository. comaniac pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 11379f7 Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) (#9821) 11379f7 is described below commit 11379f710bf9bebf4a7a0cf6c0943899047d11ed Author: masahi <masahi...@gmail.com> AuthorDate: Tue Jan 4 02:32:36 2022 +0900 Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) (#9821) * Fix reduce NCHWc infer layout (do not keep reduced inner c when keepdims=false) * black * lint --- src/relay/op/tensor/reduce.cc | 2 +- tests/python/relay/test_pass_alter_op_layout.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/relay/op/tensor/reduce.cc b/src/relay/op/tensor/reduce.cc index 5001925..d844bb5 100644 --- a/src/relay/op/tensor/reduce.cc +++ b/src/relay/op/tensor/reduce.cc @@ -176,7 +176,7 @@ InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs, if (params->exclude) { // The primal axis is not reduced, so keep the input packed dim. inferred_out_string += packed_dim; - } else { + } else if (params->keepdims) { // If the primal axis is part of reduce axes in the original layout, the inner dim // becomes 1 after reduction. inferred_out_string += "1" + layout_dim; diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py index 7514a93..ea7fe0b 100644 --- a/tests/python/relay/test_pass_alter_op_layout.py +++ b/tests/python/relay/test_pass_alter_op_layout.py @@ -24,6 +24,7 @@ from tvm.relay.testing.temp_op_attr import TempOpAttr from tvm.relay.testing import run_infer_type import numpy as np import tvm.testing +from tvm.relay import testing def run_opt_pass(expr, passes): @@ -1452,5 +1453,23 @@ def test_conv2d_strided_slice_packed_to_unpacked(): assert tvm.ir.structural_equal(a, b) +def test_conv2d_reduce_channels(): + x = relay.var("data", shape=(1, 8, 48, 48)) + y = relay.nn.conv2d( + data=x, + weight=relay.var("weight"), + kernel_size=(1, 1), + channels=8, + dilation=1, + strides=(47, 47), + ) + z = relay.argmin(y, axis=1) + + mod, params = testing.create_workload(z) + + with tvm.transform.PassContext(opt_level=3): + relay.build(mod, params=params, target="llvm") + + if __name__ == "__main__": pytest.main([__file__])