anirudh2290 commented on a change in pull request #15118: Conversion from FP32 model to Mixed Precision model
URL: https://github.com/apache/incubator-mxnet/pull/15118#discussion_r295600932
 
 

 ##########
 File path: src/nnvm/low_precision_pass.cc
 ##########
 @@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ *  Copyright (c) 2016 by Contributors
+ * \file low_precision_pass.cc
+ * \brief Return a new graph with amp_cast and amp_multicast operators inserted wherever required
+ */
+
+#include <nnvm/node.h>
+#include <nnvm/graph.h>
+#include <nnvm/pass.h>
+#include <nnvm/op_attr_types.h>
+#include <mxnet/base.h>
+#include <algorithm>
+#include <functional>
+
+namespace mxnet {
+using nnvm::Symbol;
+using nnvm::Node;
+using nnvm::NodePtr;
+using nnvm::NodeEntry;
+using nnvm::Graph;
+
+// create a node for operator : op_name with name : node_name
+NodePtr CreateNode(std::string op_name, std::string node_name) {
+  NodePtr node = Node::Create();
+  node->attrs.name = node_name;
+  if (op_name == "nullptr") {
+    node->attrs.op = nullptr;
+    // ugly workaround because VariableParam is not exposed
+    node->attrs.parsed = nnvm::Symbol::CreateVariable(node->attrs.name)
+                             .outputs[0]
+                             .node->attrs.parsed;
+  } else {
+    node->attrs.op = Op::Get(op_name);
+  }
+  return node;
+}
+
+NodePtr InsertNode(std::string op_name, std::string node_name, NodePtr current,
+                   NodeEntry previous) {
+    NodePtr node = CreateNode(op_name, node_name);
+    node->inputs.emplace_back(previous);
+    current->inputs.emplace_back(NodeEntry{node, 0, 0});
+    return node;
+}
+
+// get suffix for a node entry so that it can be used for 
amp_cast/amp_multicast node name
+std::string GetSuffix(const nnvm::NodeEntry &e,
+                      const std::unordered_map<Node*, NodePtr> &mirror_map) {
+  static const auto &flist_outputs =
+      nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListOutputNames");
+  std::string suffix = "";
+  NodePtr mirror_node = mirror_map.at(e.node.get());
+  if (mirror_node->op() != nullptr) {
+      auto list_output_names_func = flist_outputs.get(e.node->op(), nullptr);
+      if (list_output_names_func != nullptr) {
+          std::vector<std::string> names = 
list_output_names_func(e.node->attrs);
+          suffix = "_" + names[e.index];
+      } else {
+          suffix = "_" + std::to_string(e.index);
+      }
+  }
+  return suffix;
+}
+
+// add amp_cast node between curr_node and input
+void AddCastNode(const nnvm::NodeEntry &e, const std::string &suffix,
+                 const nnvm::NodeEntry &input, const std::string dtype,
+                 nnvm::NodeEntryMap<NodeEntry> *mirror_entry_map,
+                 NodePtr curr_node) {
+  NodePtr cast_node =
+      InsertNode("amp_cast", e.node->attrs.name + suffix + "_amp_cast_" + 
dtype,
+                 curr_node, input);
+  cast_node->attrs.dict["dtype"] = dtype;
+  cast_node->op()->attr_parser(&(cast_node->attrs));
+  (*mirror_entry_map)[e] = NodeEntry{cast_node, 0, e.version};
+  return;
+}
+
+// add amp_multicast node between curr_node and inputs
+void AddMultiCastNode(const std::vector<NodeEntry> &inputs,
+                      const std::string &node_name,
+                      const std::unordered_map<Node *, NodePtr> &mirror_map,
+                      NodePtr curr_node) {
+    NodePtr node = CreateNode("amp_multicast",
+                              inputs[0].node->attrs.name + node_name + 
"_amp_multicast");
+    for (size_t i = 0; i < inputs.size(); ++i) {
+    const auto &e = inputs[i];
+    NodePtr mirror_node = mirror_map.at(e.node.get());
+    NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
+    node->inputs.emplace_back(mirror_entry);
+    }
+    node->attrs.dict["num_outputs"] = std::to_string(inputs.size());
+    node->op()->attr_parser(&(node->attrs));
+    for (uint32_t i = 0; i < inputs.size(); ++i) {
+    const auto &e = inputs[i];
+    curr_node->inputs.emplace_back(
+        NodeEntry{node, static_cast<uint32_t>(i), e.version});
+    }
+    return;
+}
+
+bool CheckConditionalFP32(
+    const std::unordered_map<
+        std::string, std::unordered_map<std::string, std::vector<std::string>>>
+        &conditional_fp32_ops,
+    const std::unordered_set<std::string> &excluded_syms, NodePtr node) {
+  if (node->is_variable() || (excluded_syms.count(node->attrs.name) > 0) ||
+      conditional_fp32_ops.count(node->op()->name) == 0) {
+    return false;
+  } else {
+    // Iterate through all conditional ops
+    auto it = conditional_fp32_ops.find(node->op()->name);
+    if (it != conditional_fp32_ops.end()) {
+      auto it_params = it->second;
+      // For each param name, iterate through param values to check
+      // if the provided param name is equal to any of the values
+      for (auto it_param = it_params.begin(); it_param != it_params.end();
+           it_param++) {
+        auto param_key = node->attrs.dict.find(it_param->first);
+        if (param_key != node->attrs.dict.end()) {
+          auto it_param_vals = it_param->second;
+          if (std::find(it_param_vals.begin(), it_param_vals.end(),
+                        param_key->second) != it_param_vals.end()) {
+            return true;
+          }
+        }
+      }
+    }
+    return false;
+  }
+}
+
+Graph ReducePrecision(Graph &&src) {
+  const auto target_dtype_ops =
+      src.GetAttr<std::unordered_set<std::string>>("target_dtype_ops");
+  const auto fp32_ops =
+      src.GetAttr<std::unordered_set<std::string>>("fp32_ops");
+  const auto widest_dtype_ops =
+      src.GetAttr<std::unordered_set<std::string>>("widest_dtype_ops");
+  const auto target_dtype = src.GetAttr<int>("target_dtype");
+  const auto excluded_syms = 
src.GetAttr<std::unordered_set<std::string>>("excluded_syms");
+  const auto conditional_fp32_ops = src.GetAttr<std::unordered_map<
+      std::string, std::unordered_map<std::string, std::vector<std::string>>>>(
+      "conditional_fp32_ops");
+
+  CHECK(target_dtype == mshadow::kFloat16)
+      << "Only float16 target_dtype is supported yet";
+
+  // Additional data structures to share common cast node inputs among 
different nodes
+  std::unordered_map<Node *, NodePtr> mirror_map;
+  nnvm::NodeEntryMap<NodeEntry> mirror_fp32_map;
+  nnvm::NodeEntryMap<NodeEntry> mirror_target_dtype_map;
+
+  // Visit nodes in a topologically sorted order
+  DFSVisit(src.outputs, [&](const NodePtr &node) {
+    NodePtr new_node = Node::Create();
 
 Review comment:
   Yes, you are right. My editor was getting confused by the number of arguments, 
probably because of std::forward. I have fixed this now.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to