windclarion commented on a change in pull request #5277: [BYOC] Refine
AnnotateTarget and MergeCompilerRegion Passes
URL: https://github.com/apache/incubator-tvm/pull/5277#discussion_r405903872
##########
File path: src/relay/transforms/annotate_target.cc
##########
@@ -19,131 +19,155 @@
/*!
* \file src/relay/transforms/annotate_target.cc
- * \brief Wraps a call with compiler_begin and compiler_end to indicate that
- * the op of this call node will use external compiler.
+ * \brief Wraps an expr with compiler_begin and compiler_end to indicate that
+ * this expr should be handled by the external compiler.
*/
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>
+#include <tvm/runtime/container.h>
namespace tvm {
namespace relay {
namespace annotate_target {
-// Cache compiler_begin op for equivalence check.
-static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
+const PackedFunc* begin_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
+const PackedFunc* end_op =
runtime::Registry::Get("relay.op.annotation._make.compiler_end");
// A helper class to insert annotation boundaries for a program region that
will
// be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator {
public:
- explicit AnnotateTargetWrapper(const std::string& target) : target_(target)
{}
+ explicit AnnotateTargetWrapper(Array<runtime::String> targets) :
targets_(std::move(targets)) {}
+
+ /*!
+ * \brief This function annotates a compiler end and a compiler begin to all
arguments.
+ *
+ * The compiler end is based on the arg target while the compiler begin is
based on the given
+ * target. If target is not given and all arguments are going to the same
target, then we will
+ * use that target; otherwise we use default for this op. Note that all arg
exprs must be
+ * available in op_expr_to_target before calling this function.
+ *
+ * \param args An array of arguments of the given node.
+ * \param target The target of the current node.
+ * \return A pair of target and annotated argument expressions.
+ */
+ std::pair<std::string, Array<Expr>> AnnotateArgs(const Array<Expr>& args,
+ const std::string& target =
"") {
+ std::string ref_target = "";
+ Array<Expr> compiler_ends;
+ for (auto arg : args) {
+ if (op_expr_to_target_.find(arg) != op_expr_to_target_.end()) {
+ std::string arg_target = op_expr_to_target_[arg];
+ compiler_ends.push_back(InsertAnnotation(arg, arg_target, end_op));
+ if (ref_target == "") {
+ ref_target = arg_target;
+ } else if (ref_target != arg_target) {
+ ref_target = "default";
+ }
+ } else {
+ // Input vars.
+ compiler_ends.push_back(arg);
+ }
+ }
+
+ // Determine compiler begin target.
+ std::string op_target = (target == "") ? ref_target : target;
+
+ Array<Expr> compiler_begins;
+ for (const auto& end : compiler_ends) {
+ compiler_begins.push_back(InsertAnnotation(end, op_target, begin_op));
+ }
- Expr Annotate(const Expr& expr) {
- return InsertEnd(Mutate(expr));
+ return {op_target, compiler_begins};
}
- bool IsSupported(const Expr& expr) {
- if (expr->IsInstance<CallNode>()) {
- Call call = Downcast<Call>(expr);
- auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
- if (call->op->IsInstance<OpNode>()) {
- Op op = Downcast<Op>(call->op);
- CHECK(op.defined());
- if (fannotate.count(op)) {
- return fannotate[op](call->attrs, call->args);
+ Expr InsertAnnotation(const Expr& expr, const std::string& target, const
PackedFunc* ann_op) {
+ Expr new_op = (*ann_op)(expr, target);
+ new_op->checked_type_ = expr->checked_type_;
+ return new_op;
+ }
+
+ Expr VisitExpr_(const CallNode* cn) final {
+ // Supported targets for this node. The order implies the priority.
+ std::vector<std::string> supported_targets;
+
+ // Check which targets this op can be offloaded.
+ if (cn->op->IsInstance<OpNode>()) {
+ // TVM operators: Check target specific op checking function and add to
supported_targets
+ // if it is supported.
+ Op op = Downcast<Op>(cn->op);
+ CHECK(op.defined());
+ for (const auto& target : this->targets_) {
+ auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." +
std::string(target));
Review comment:
I understand, but maybe what I said was not clear. I know a composite function
doesn't enter the OpNode branch. My point is that I only use composite functions,
so I don't define an FTVMAnnotateTarget attr for any op under target.xxxx. Then, for
any OpNode, `auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target))`
will fail, because no FTVMAnnotateTarget attr was ever registered for that target.
The annotation mechanism can handle both ops and composite functions, and the two
should be independent of each other.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services