KellenSunderland commented on a change in pull request #15399: Add unit tests for TensorRT integration and fix some bugs
URL: https://github.com/apache/incubator-mxnet/pull/15399#discussion_r311301113
 
 

 ##########
 File path: src/operator/subgraph/tensorrt/tensorrt-inl.h
 ##########
 @@ -109,13 +111,70 @@ class TensorrtSelector : public SubgraphSelector {
 
   bool isTRTCompatible(const nnvm::Node &n) {
     const std::string op_name = n.op()->name;
+    if (op_name == "FullyConnected") {
+      const auto& param = nnvm::get<FullyConnectedParam>(n.attrs.parsed);
+      return !param.no_bias;
+    }
+
     if (op_name == "Pooling") {
-      return (n.attrs.dict.at("pool_type") == "avg" ||
-          n.attrs.dict.at("pool_type") == "max");
+      const auto& param = nnvm::get<PoolingParam>(n.attrs.parsed);
+      if (param.layout.has_value()) {
+        if (param.layout.value() == mshadow::kNHWC) {
+          LOG(INFO) << "Warning: NHWC layout (node: " << n.attrs.name
+                    << ") is not supported by TensorRT";
+          return false;
+        } else if (param.layout.value() == mshadow::kNDHWC) {
+          LOG(INFO) << "Warning: NDHWC layout (node: " << n.attrs.name
+                    << ") is not supported by TensorRT";
+          return false;
+        }
+      }
+      if (param.pooling_convention != pool_enum::kValid && !param.global_pool)
+        return false;
+      if (param.pool_type == pool_enum::kAvgPooling) {
+        if ((!param.global_pool) &&
+            (!param.count_include_pad.has_value() || 
param.count_include_pad.value()))
+          return false;
+        return true;
+      } else if (param.pool_type == pool_enum::kMaxPooling) {
+        return true;
+      } else {
+        return false;
+      }
     }
 
-    if (unconditionalTRTops.count(op_name)) {
-      return true;
+    if (op_name == "Convolution") {
+      const auto& param = nnvm::get<ConvolutionParam>(n.attrs.parsed);
+      if (!param.layout.has_value())
+        return true;
+      switch (param.layout.value()) {
+        case mshadow::kNCHW:
+        case mshadow::kNCW:
+        case mshadow::kNCDHW:
+          return true;
+        case mshadow::kNHWC:
+          LOG(INFO) << "Warning: NHWC layout (node: " << n.attrs.name
+                    << ") is not supported by TensorRT";
+          return false;
+        case mshadow::kNDHWC:
+          LOG(INFO) << "Warning: NDHWC layout (node: " << n.attrs.name
+                    << ") is not supported by TensorRT";
+          return false;
+        default:
+          LOG(INFO) << "Warning: Layout (node: " << n.attrs.name
+                    << ") is unknown (so unsupported by TensorRT)";
+          return false;
+      }
+    }
+
+    if (op_name == "Concat") {
+      const auto& param = nnvm::get<ConcatParam>(n.attrs.parsed);
+      return (param.dim != 0);
+    }
+
+    if (op_name == "Dropout") {
 
 Review comment:
  OK, this is a non-blocking comment for this PR. I'm thinking about adding a
warning in the future when people use TRT with operations that don't make
sense at inference time (Dropout, Identity, empty Concats or Copies, etc.).

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to