manupa-arm commented on a change in pull request #5320: [BYOC] Prevent duplicate outputs in subgraph Tuple
URL: https://github.com/apache/incubator-tvm/pull/5320#discussion_r408819374
 
 

 ##########
 File path: src/relay/transforms/partition_graph.cc
 ##########
 @@ -456,18 +370,111 @@ class Partitioner : public ExprMutator {
   }
 
   /*!
-   * \brief Get the index of the return(output);
-   * this is to be used as tuplegetitem idx
+   * \brief This function is called first time that we encounter a compiler_end
+   * node to create the function for the subgraph.
    */
-  int GetRetIdx(AnnotatedRegion sg, const Expr& arg) {
-    int idx = 0;
-    for (auto arg_ : sg->GetOutputs()) {
-      if (arg == arg_) {
-        return idx;
+  void CreateFunction(AnnotatedRegion region, const CallNode* call) {
+    // Create fields which is a unique list of outputs. Also populate
+    // region_return_indices_ map which maps parent of compiler_end node to
+    // corresponding index in fields.
+    Array<Expr> fields;
+    int i = 0;
+    for (auto ret : region->GetOutputs()) {
+      auto ret_node = Downcast<Call>(ret)->args[0];
+      // Don't duplicate outputs.
+      if (!region_return_indices_.count(region) ||
+          !region_return_indices_[region].count(ret_node)) {
+        auto ret_expr = VisitExpr(ret_node);
+        fields.push_back(ret_expr);
+        region_return_indices_[region][ret_node] = i;
+        i++;
       }
-      idx++;
     }
-    return -1;
+
+    Array<Var> params;
+    Array<Expr> param_expr;
+    std::unordered_map<std::string, runtime::NDArray> params_bind;
+
+    for (auto pair : region_args[region]) {
+      params.push_back(pair.first);
+      if (const auto* cn = pair.second.as<ConstantNode>()) {
+        params_bind[pair.first->name_hint()] = cn->data;
+      } else {
+        param_expr.push_back(pair.second);
+      }
+    }
+
+    Function global_region_func;
+    if (fields.size() == 1) {
+      // If there are only a single output; no need to add a tuple
+      global_region_func =
+          Function(params, fields[0], call->args[0]->checked_type_, {}, DictAttrs());
+    } else {
+      auto tuple = Tuple(fields);
+      global_region_func = Function(params, tuple, tuple->checked_type_, {}, DictAttrs());
+    }
+
+    std::string target = call->attrs.as<CompilerAttrs>()->compiler;
+    std::string name = target + "_" + std::to_string(region->GetID());
+
+    global_region_func = WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol,
+                                  runtime::String(name));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
+    global_region_func = WithAttr(std::move(global_region_func), attr::kCompiler,
+                                  tvm::runtime::String(target));
+    global_region_func =
+        WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));
+
+    // Constant propagation
+    if (!params_bind.empty()) {
+      global_region_func = backend::BindParamsByName(global_region_func, params_bind);
+    }
+
+    std::string fname = name;
+    CHECK(!module_->ContainGlobalVar(fname))
+        << "Global function " << fname << " already exists";
+    // Create a global function and add it to the IRModule for the region.
+    // This way we lift the functions that should be handled by external
+    // codegen to the module scope and rely on the pass manager to prevent
+    // relay function level passes (i.e. simplify inference and fusion)
+    // optimizing it.
+    GlobalVar glob_func(fname);
+    module_->Add(glob_func, global_region_func);
+
+    // The return type of callnode is the same as the type of the
+    // compiler_end node.
+    auto ret = Call(glob_func, param_expr);
+    region_function_calls[region] = ret;
+  }
+
+  /*!
+   * \brief Get the return(output) of the function for compiler end node "end_arg".
+   * This will return either a Call (for a function with a single output) or a
+   * TupleGetItem (for a function with multiple outputs).
+   */
+  Expr GetFunctionOutput(AnnotatedRegion region, const Expr& end_arg) {
+    Expr arg = Downcast<Call>(end_arg)->args[0];
+    // Function has one output.
+    if (region_return_indices_[region].size() == 1) {
+      return region_function_calls[region];
+    }
+    // Function has multiple outputs.
+    // Use already made TupleGetItem.
+    if (region_return_tuplegetitem_.count(region) &&
+        region_return_tuplegetitem_[region].count(arg)) {
+      return region_return_tuplegetitem_[region][arg];
+    }
+    // Create new TupleGetItem.
+    CHECK(region_return_indices_.count(region) &&
+          region_return_indices_[region].count(arg));
+    int index = region_return_indices_[region][arg];
+
+    auto func_call = region_function_calls[region];
+    auto tuple_get_item_ = TupleGetItem(func_call, index);
+    tuple_get_item_->checked_type_ = arg->checked_type_;
+    region_return_tuplegetitem_[region][arg] = tuple_get_item_;
+    return tuple_get_item_;
 
 Review comment:
   Do we need to std::move this?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to