xinyu-intel commented on a change in pull request #15910: [Quantization]support exclude operators while quantization URL: https://github.com/apache/incubator-mxnet/pull/15910#discussion_r321569610
########## File path: src/operator/quantization/quantize_graph_pass.cc ########## @@ -102,28 +102,57 @@ std::vector<NodeEntry> OfflineParams(std::vector<NodeEntry>&& outputs, return outputs; } -inline NodePtr NeedQuantize(NodePtr node, const std::unordered_set<std::string>& excluded_nodes) { +// To check if a node is registered with a computation function on a target device. +bool isRegistered(NodePtr node, const int& dev_type) { + const auto& op = node->op(); + Context ctx = Context::Create(static_cast<Context::DeviceType>(dev_type), 0); + FCompute fcompute = common::GetFCompute<FCompute>(op, "FCompute", ctx); + FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", ctx); + FStatefulCompute fcomputestateful = + common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", ctx); + FStatefulComputeEx fcomputestateful_ex = + common::GetFCompute<FStatefulComputeEx>(op, "FStatefulComputeEx", ctx); + return (fcompute != nullptr || fcomp_ex != nullptr || + fcomputestateful != nullptr || fcomputestateful_ex != nullptr); +} + +inline NodePtr NeedQuantize( + NodePtr node, const std::unordered_set<std::string>& excluded_nodes, + const std::unordered_set<std::string>& excluded_ops, + const int& dev_type) { std::unordered_map<NodePtr, NodePtr> quantized_node; static auto& quantized_op_map = Op::GetAttr<mxnet::FQuantizedOp>("FQuantizedOp"); static auto& fexec_type = nnvm::Op::GetAttr<FExecType>("FExecType"); const auto& op = node->op(); if (op && quantized_op_map.count(op)) { bool need = true; - if (excluded_nodes.count(node->attrs.name)) { + // If the quantized node is not registered with a computation function, the node + // will be excluded automatically. 
+ auto q_ptr = quantized_op_map[node->op()]; + auto qnode = q_ptr(node->attrs); + if (!isRegistered(qnode, dev_type)) { + LOG(INFO) << "Neither FCompute nor FComputeEx registered, " << node->op()->name + << " excluded automatically."; need = false; - } else if (!node->attrs.subgraphs.empty()) { - ExecType exec_type = fexec_type.count(op) ? fexec_type[op](node->attrs) : ExecType::kSync; - if (exec_type != ExecType::kSubgraphExec) { - // This is a fused subgraph node, try to match inner node. - CHECK_EQ(node->attrs.subgraphs.size(), 1); - auto subgraph_sym = node->attrs.subgraphs[0]; - DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr& n) { - if (n->is_variable()) return; - if (excluded_nodes.count(n->attrs.name)) { - need = false; - } - }); + } else { + if (excluded_nodes.count(node->attrs.name) || + excluded_ops.count(node->op()->name)) { + need = false; + } else if (!node->attrs.subgraphs.empty()) { + ExecType exec_type = fexec_type.count(op) ? fexec_type[op](node->attrs) : ExecType::kSync; + if (exec_type != ExecType::kSubgraphExec) { + // This is a fused subgraph node, try to match inner node. + CHECK_EQ(node->attrs.subgraphs.size(), 1); + auto subgraph_sym = node->attrs.subgraphs[0]; + DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr& n) { + if (n->is_variable()) return; + if (excluded_nodes.count(n->attrs.name) || + excluded_ops.count(node->op()->name)) { Review comment: I found we cannot exclude fused conv layers when setting `excluded_op_names=['Convolution']`. Is it necessary to check the inner node here? ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services