This is an automated email from the ASF dual-hosted git repository. anijain2305 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 0778afd Use channels from attrs if possible (#7011) 0778afd is described below commit 0778afd6d0fb0283fba5d4839f27e2ac548a3284 Author: Trevor Morris <trevm...@amazon.com> AuthorDate: Tue Dec 1 22:04:43 2020 -0800 Use channels from attrs if possible (#7011) --- src/runtime/contrib/tensorrt/tensorrt_ops.cc | 4 ++++ tests/python/contrib/test_tensorrt.py | 5 +++++ 2 files changed, 9 insertions(+) diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc index 057743c..c3ff1c4 100644 --- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc +++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc @@ -243,6 +243,10 @@ class Conv2DOpConverter : public TensorRTOpConverter { auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding"); int groups = std::stoi(params->node.GetAttr<std::vector<std::string>>("groups")[0]); int channels = weight_shape[0]; + if (params->node.HasAttr("channels") && + !params->node.GetAttr<std::vector<std::string>>("channels")[0].empty()) { + channels = std::stoi(params->node.GetAttr<std::vector<std::string>>("channels")[0]); + } // TRT conv2d op doesn't support asymmetric padding before 5.1, so we // workaround by adding a padding layer before the pooling op. nvinfer1::DimsHW prepadding, postpadding; diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py index 10c311a..de98222 100644 --- a/tests/python/contrib/test_tensorrt.py +++ b/tests/python/contrib/test_tensorrt.py @@ -352,6 +352,7 @@ def test_conv2d(): padding=(0, 0), strides=(1, 1), dilation=(1, 1), + channels=None, ): x = relay.var("x", shape=(x_shape), dtype="float32") kernel = relay.var("kernel", shape=(k_shape), dtype="float32") @@ -363,6 +364,7 @@ def test_conv2d(): padding=padding, strides=strides, dilation=dilation, + channels=channels, ) f = relay.Function([x, kernel], out) return f, {"x": x_shape, "kernel": k_shape}, ["kernel"] @@ -380,6 +382,9 @@ def test_conv2d(): dilation=dilation, ) ) + run_and_verify_func( + get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24) + ) def test_conv2d_nhwc():