[GitHub] [incubator-tvm] mbaret commented on a change in pull request #5134: [RELAY] Add MergeCompilerRegions pass

2020-03-27 Thread GitBox
mbaret commented on a change in pull request #5134: [RELAY] Add 
MergeCompilerRegions pass
URL: https://github.com/apache/incubator-tvm/pull/5134#discussion_r399364691
 
 

 ##
 File path: src/relay/transforms/merge_compiler_regions.cc
 ##
 @@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * \file src/relay/transforms/merge_compiler_regions.cc
+ *
+ * \brief After operators have been annotated with the targets that support
+ * them, this pass creates regions of the operators for each target. It
+ * is guaranteed that the regions will have a topological ordering so that
+ * no data dependency issues exist.
+ *
+ * This pass only introduces annotations to indicate the regions.
+ * partition_graph must subsequently be called to lift these regions out
+ * as external functions.
+ */
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#include "../analysis/annotated_region_set.h"
+
+
+namespace tvm {
+namespace relay {
+namespace partitioning {
+
+// Cache the compiler_begin and compiler_end annotation ops once, so that
+// equivalence checks against them avoid repeated registry lookups.
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
+
+/*! \brief A prerequisite pass for the MergeCompilerRegions pass.
+ *  The AnnotateRestDefault pass adds "default" compiler annotations to
+ *  nodes that are not already annotated. This ensures that no
+ *  un-annotated nodes remain when the MergeCompilerRegions pass is run,
+ *  because MergeCompilerRegions assumes that every node is annotated.
+ */
+class AnnotateRestDefault : public ExprMutator {
+ public:
+  /*! \brief Build the set of annotated regions of \p expr from its
+   * existing compiler_begin / compiler_end annotations, so that later
+   * visits can query which region (if any) a node belongs to.
+   * \param expr The expression whose existing annotations define the
+   * regions (presumably the same function later passed to Annotate).
+   */
+  explicit AnnotateRestDefault(const Expr& expr) {
+  regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, 
compiler_end_op);
+  }
+
+  /*! \brief Add "default" compiler annotations to the nodes of the
+   * function \p expr that are not already inside an annotated region.
+   * \param expr The function to annotate (it is downcast to a Function).
+   * \return The mutated function. If the original body was not inside an
+   * annotated region, the mutated body is additionally wrapped in a
+   * compiler_end and the function is rebuilt around it.
+   */
+  Expr Annotate(const Expr& expr) {
+// The input is expected to be a function; remember it so that
+// AddCompilerEnd can skip the function's own parameters.
+func_ = Downcast(expr);
+
+// Corner case CC1: if the last node (the function body) does not belong
+// to a region, we need to add a compiler_end after it. The region lookup
+// uses the original body, i.e. it is done before mutation.
+auto region = regions_->GetRegion(func_->body);
+auto mutated_expr = this->VisitExpr(expr);
+if (!region.defined()) {
+  func_ = Downcast(mutated_expr);
+  // CC1: wrap the mutated body in a compiler_end.
+  auto body = AddCompilerEnd_(func_->body);
+  // NOTE(review): the function is rebuilt with `{}` and an empty
+  // DictAttrs(), so attributes of the original function appear to be
+  // dropped — confirm this is intended. `body` is a freshly constructed
+  // expression, so its checked_type_ may not be populated yet.
+  func_ = Function(func_->params, body,
+   body->checked_type_, {}, DictAttrs());
+  return Downcast(func_);
+} else {
+  return mutated_expr;
+}
+  }
+
+  /*! \brief Visit \p expr and wrap the visited result in a compiler_end
+   * when \p expr does NOT belong to any annotated region AND the visited
+   * result is not one of the parameters of the function being annotated.
+   * \param expr The argument expression to visit. Its region is looked up
+   * on this original expression, before mutation.
+   * \return The visited expression, wrapped in a compiler_end when both
+   * conditions hold; otherwise the visited expression as-is.
+   */
+  Expr AddCompilerEnd(const Expr& expr) {
+auto region = regions_->GetRegion(expr);
+auto visited_expr = VisitExpr(expr);
+
+// Compiler ends are added only to nodes that do NOT have a region
+// AND that are not parameters of the original function. Note that the
+// parameter check compares the visited expression against func_->params.
+if (!region.defined() &&
+   std::find(func_->params.begin(),
+ func_->params.end(), visited_expr)
+   == func_->params.end()) {
+  return AddCompilerEnd_(visited_expr);
+} else {
+  return visited_expr;
+}
+  }
+
+  /*! \brief Unconditionally wrap \p expr in a compiler_end annotation.
+   * The compiler_end constructor is fetched from the global registry
+   * ("relay.op.annotation._make.compiler_end") on every call; the CHECK
+   * fails if it is not registered.
+   * \param expr The expression to wrap.
+   * \return The expression produced by the compiler_end constructor.
+   */
+  Expr AddCompilerEnd_(const Expr& expr) {
+const auto* end_op =
+  runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+CHECK(end_op);
+// NOTE(review): target_ is a member declared outside this snippet;
+// presumably the "default" target name — confirm.
+Expr end = (*end_op)(expr, target_);
+return end;
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+auto op_node = call->op.as();
+auto ret = GetRef(call);
+
+Array args;
+
+// Add compiler ends if the parent is supported
+for (auto arg : call->args) {
+  args.push_back(AddCompilerEnd(arg));
+}
+
+if (op_node == nullptr || call->attrs.as() == nullptr) {
+  // Skip annotatation ops, only add default compiler to actual compute 
nodes
+
+  auto region = regions_->GetRegion(ret);
+  if (!region.defined()) {
+// if the current node does not belong to annotated region
+// annotate the all incoming edges (args)
+// with 

[GitHub] [incubator-tvm] mbaret commented on a change in pull request #5134: [RELAY] Add MergeCompilerRegions pass

2020-03-26 Thread GitBox
mbaret commented on a change in pull request #5134: [RELAY] Add 
MergeCompilerRegions pass
URL: https://github.com/apache/incubator-tvm/pull/5134#discussion_r398568581
 
 

 ##
 File path: src/relay/transforms/merge_compiler_regions.cc
 ##
 @@ -0,0 +1,352 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * \file src/relay/transforms/merge_compiler_regions.cc
+ *
+ * \brief After operators have been annotated with the targets that support
+ * them, this pass creates regions of the operators for each target. It
+ * is guaranteed that the regions will have a topological ordering so that
+ * no data dependency issues exist.
+ *
+ * This pass only introduces annotations to indicate the regions.
+ * partition_graph must subsequently be called to lift these regions out
+ * as external functions.
+ */
+
+#include 
+#include 
+#include 
+#include 
+#include 
+#include 
+
+#include 
+#include 
+#include 
+#include 
+
+#include "../analysis/annotated_region_set.h"
+
+
+namespace tvm {
+namespace relay {
+namespace partitioning {
+
+// Cache the compiler_begin and compiler_end annotation ops once, so that
+// equivalence checks against them avoid repeated registry lookups.
+static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+static const Op& compiler_end_op = Op::Get("annotation.compiler_end");
+
+/*! \brief A prerequisite pass for the MergeCompilerRegions pass.
+ *  The AnnotateRestDefault pass adds "default" compiler annotations to
+ *  nodes that are not already annotated. This ensures that no
+ *  un-annotated nodes remain when the MergeCompilerRegions pass is run,
+ *  because MergeCompilerRegions assumes that every node is annotated.
+ */
+class AnnotateRestDefault : public ExprMutator {
+ public:
+  /*! \brief Build the set of annotated regions of \p expr from its
+   * existing compiler_begin / compiler_end annotations, so that later
+   * visits can query which region (if any) a node belongs to.
+   * \param expr The expression whose existing annotations define the
+   * regions (presumably the same function later passed to Annotate).
+   */
+  explicit AnnotateRestDefault(const Expr& expr) {
+  regions_ = AnnotatedRegionSet::Create(expr, compiler_begin_op, 
compiler_end_op);
+  }
+
+  /*! \brief Add "default" compiler annotations to the nodes of the
+   * function \p expr that are not already inside an annotated region.
+   * \param expr The function to annotate (it is downcast to a Function).
+   * \return The mutated function. If the original body was not inside an
+   * annotated region, the mutated body is additionally wrapped in a
+   * compiler_end and the function is rebuilt around it.
+   */
+  Expr Annotate(const Expr& expr) {
+// The input is expected to be a function; remember it so that
+// AddCompilerEnd can skip the function's own parameters.
+func_ = Downcast(expr);
+
+// Corner case CC1: if the last node (the function body) does not belong
+// to a region, we need to add a compiler_end after it. The region lookup
+// uses the original body, i.e. it is done before mutation.
+auto region = regions_->GetRegion(func_->body);
+auto mutated_expr = this->VisitExpr(expr);
+if (!region.defined()) {
+  func_ = Downcast(mutated_expr);
+  // CC1: wrap the mutated body in a compiler_end.
+  auto body = AddCompilerEnd_(func_->body);
+  // NOTE(review): the function is rebuilt with `{}` and an empty
+  // DictAttrs(), so attributes of the original function appear to be
+  // dropped — confirm this is intended. `body` is a freshly constructed
+  // expression, so its checked_type_ may not be populated yet.
+  func_ = Function(func_->params, body,
+   body->checked_type_, {}, DictAttrs());
+  return Downcast(func_);
+} else {
+  return mutated_expr;
+}
+  }
+
+  /*! \brief Visit \p expr and wrap the visited result in a compiler_end
+   * when \p expr does NOT belong to any annotated region AND the visited
+   * result is not one of the parameters of the function being annotated.
+   * \param expr The argument expression to visit. Its region is looked up
+   * on this original expression, before mutation.
+   * \return The visited expression, wrapped in a compiler_end when both
+   * conditions hold; otherwise the visited expression as-is.
+   */
+  Expr AddCompilerEnd(const Expr& expr) {
+auto region = regions_->GetRegion(expr);
+auto visited_expr = VisitExpr(expr);
+
+// Compiler ends are added only to nodes that do NOT have a region
+// AND that are not parameters of the original function. Note that the
+// parameter check compares the visited expression against func_->params.
+if (!region.defined() &&
+   std::find(func_->params.begin(),
+ func_->params.end(), visited_expr)
+   == func_->params.end()) {
+  return AddCompilerEnd_(visited_expr);
+} else {
+  return visited_expr;
+}
+  }
+
+  /*! \brief Unconditionally wrap \p expr in a compiler_end annotation.
+   * The compiler_end constructor is fetched from the global registry
+   * ("relay.op.annotation._make.compiler_end") on every call; the CHECK
+   * fails if it is not registered.
+   * \param expr The expression to wrap.
+   * \return The expression produced by the compiler_end constructor.
+   */
+  Expr AddCompilerEnd_(const Expr& expr) {
+const auto* end_op =
+  runtime::Registry::Get("relay.op.annotation._make.compiler_end");
+CHECK(end_op);
+// NOTE(review): target_ is a member declared outside this snippet;
+// presumably the "default" target name — confirm.
+Expr end = (*end_op)(expr, target_);
+return end;
+  }
+
+  Expr VisitExpr_(const CallNode* call) final {
+auto op_node = call->op.as();
+auto ret = GetRef(call);
+
+Array args;
+
+// Add compiler ends if the parent is supported
+for (auto arg : call->args) {
+  args.push_back(AddCompilerEnd(arg));
+}
+
+if (op_node == nullptr || call->attrs.as() == nullptr) {
+  // Skip annotatation ops, only add default compiler to actual compute 
nodes
+
+  auto region = regions_->GetRegion(ret);
+  if (!region.defined()) {
+// if the current node does not belong to annotated region
+// annotate the all incoming edges (args)
+// with