quic-sanirudh commented on code in PR #17075:
URL: https://github.com/apache/tvm/pull/17075#discussion_r1635041274


##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& 
inplace_indices,
+                                              int num_inputs) {  // Returns one index per entry of inplace_indices.
+  Array<Integer> ret;
+  int last_idx = num_inputs;  // Next index handed out for a negative entry; counting starts at num_inputs.
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {  // Non-negative entry: copied through unchanged.
+      ret.push_back(Integer(i));
+    } else {  // Negative entry: NOTE(review) presumably an output that is not written in place -- confirm.
+      ret.push_back(Integer(last_idx));
+      last_idx++;  // Each negative entry consumes one index after the previous one.
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool 
in_place = false) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be 
either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, 
num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+      if (i < num_inputs) {
+        const auto& relax_var = Downcast<Var>(relax_args[i]);
+        relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
+      }
+      if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i); it 
!= output_idxs.end()) {
+        int result_idx = it - output_idxs.begin();
+        const auto& inplace_out_var = Downcast<Var>(relax_results[result_idx]);
+        relax_to_tir_var_map_.Set(inplace_out_var, buffer_map[tir_var]);
+      }
+    }
+  }
+
+ public:
+  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+  static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function& 
func) {
+    RelaxToTIRVarMapCollector visitor(mod);
+    visitor(func->body);
+    return visitor.relax_to_tir_var_map_;
+  }
+  void VisitBinding_(const VarBindingNode* binding) final {
+    const auto& lhs_var = binding->var;
+    const auto& value = binding->value;
+    if (const CallNode* call = value.as<CallNode>()) {
+      static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+      static const Op& call_tir_inplace_op_ = 
Op::Get("relax.call_tir_inplace");
+
+      ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+          << "Only call_tir and call_tir_inplace are supported in primitive 
function, but got: "
+          << GetRef<Expr>(call);
+      if (call->op == call_tir_inplace_op_) {
+        CollectVarMapping(call, lhs_var, /*in_place*/ true);
+      } else {
+        CollectVarMapping(call, lhs_var);
+      }
+    }
+  }
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  // size_t call_num_inputs_ = -1;
+  Map<Var, tir::Buffer> relax_to_tir_var_map_;

Review Comment:
   Yes, I did think about this issue, but I assumed that even though the same 
Relax var might map to different buffers, those buffers should have the same 
attributes (since their source is the same Relax var). I've also added a 
validation ICHECK to verify that the buffer attributes match (using structural 
equality).
   
   I've also added a test case to verify this use case, as suggested in the 
comment below.



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& 
inplace_indices,
+                                              int num_inputs) {  // Returns one index per entry of inplace_indices.
+  Array<Integer> ret;
+  int last_idx = num_inputs;  // Next index handed out for a negative entry; counting starts at num_inputs.
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {  // Non-negative entry: copied through unchanged.
+      ret.push_back(Integer(i));
+    } else {  // Negative entry: NOTE(review) presumably an output that is not written in place -- confirm.
+      ret.push_back(Integer(last_idx));
+      last_idx++;  // Each negative entry consumes one index after the previous one.
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool 
in_place = false) {

Review Comment:
   Done



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to