This is an automated email from the ASF dual-hosted git repository.

anijain2305 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0778afd  Use channels from attrs if possible (#7011)
0778afd is described below

commit 0778afd6d0fb0283fba5d4839f27e2ac548a3284
Author: Trevor Morris <trevm...@amazon.com>
AuthorDate: Tue Dec 1 22:04:43 2020 -0800

    Use channels from attrs if possible (#7011)
---
 src/runtime/contrib/tensorrt/tensorrt_ops.cc | 4 ++++
 tests/python/contrib/test_tensorrt.py        | 5 +++++
 2 files changed, 9 insertions(+)

diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index 057743c..c3ff1c4 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -243,6 +243,10 @@ class Conv2DOpConverter : public TensorRTOpConverter {
    auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding");
    int groups = std::stoi(params->node.GetAttr<std::vector<std::string>>("groups")[0]);
     int channels = weight_shape[0];
+    if (params->node.HasAttr("channels") &&
+        !params->node.GetAttr<std::vector<std::string>>("channels")[0].empty()) {
+      channels = std::stoi(params->node.GetAttr<std::vector<std::string>>("channels")[0]);
+    }
     // TRT conv2d op doesn't support asymmetric padding before 5.1, so we
     // workaround by adding a padding layer before the pooling op.
     nvinfer1::DimsHW prepadding, postpadding;
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 10c311a..de98222 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -352,6 +352,7 @@ def test_conv2d():
         padding=(0, 0),
         strides=(1, 1),
         dilation=(1, 1),
+        channels=None,
     ):
         x = relay.var("x", shape=(x_shape), dtype="float32")
         kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
@@ -363,6 +364,7 @@ def test_conv2d():
             padding=padding,
             strides=strides,
             dilation=dilation,
+            channels=channels,
         )
         f = relay.Function([x, kernel], out)
         return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
@@ -380,6 +382,9 @@ def test_conv2d():
                             dilation=dilation,
                         )
                     )
+    run_and_verify_func(
+        get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1, 1], 24)
+    )
 
 
 def test_conv2d_nhwc():

Reply via email to