quic-sanirudh commented on code in PR #17075: URL: https://github.com/apache/tvm/pull/17075#discussion_r1635041274
########## src/relax/transform/fuse_tir.cc: ########## @@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator { namespace relax { +static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices, + int num_inputs) { + Array<Integer> ret; + int last_idx = num_inputs; + for (auto idx : inplace_indices) { + int i = idx.IntValue(); + if (i >= 0) { + ret.push_back(Integer(i)); + } else { + ret.push_back(Integer(last_idx)); + last_idx++; + } + } + + return ret; +} + +class RelaxToTIRVarMapCollector : public ExprVisitor { + void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) { + GlobalVar gv = Downcast<GlobalVar>(call->args[0]); + tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv)); + const auto& buffer_map = prim_func_->buffer_map; + const auto& tir_args = prim_func_->params; + + const auto& relax_args = Downcast<Tuple>(call->args[1])->fields; + + Array<Expr> relax_results; + if (lhs_var->IsInstance<TupleNode>()) { + relax_results = Downcast<Tuple>(lhs_var)->fields; + } else { + CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var"; + relax_results = {Downcast<Var>(lhs_var)}; + } + + size_t num_inputs = relax_args.size(); + size_t num_outputs = relax_results.size(); + + Array<Integer> output_idxs; + if (in_place) { + const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>(); + CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call"; + output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs); + } else { + for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) { + output_idxs.push_back(i); + } + } + for (size_t i = 0; i < tir_args.size(); ++i) { + const auto& tir_var = Downcast<tir::Var>(tir_args[i]); + if (i < num_inputs) { + const auto& relax_var = Downcast<Var>(relax_args[i]); + relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]); + } + if (auto it = std::find(output_idxs.begin(), 
output_idxs.end(), i); it != output_idxs.end()) { + int result_idx = it - output_idxs.begin(); + const auto& inplace_out_var = Downcast<Var>(relax_results[result_idx]); + relax_to_tir_var_map_.Set(inplace_out_var, buffer_map[tir_var]); + } + } + } + + public: + explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {} + static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function& func) { + RelaxToTIRVarMapCollector visitor(mod); + visitor(func->body); + return visitor.relax_to_tir_var_map_; + } + void VisitBinding_(const VarBindingNode* binding) final { + const auto& lhs_var = binding->var; + const auto& value = binding->value; + if (const CallNode* call = value.as<CallNode>()) { + static const Op& call_tir_op_ = Op::Get("relax.call_tir"); + static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace"); + + ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_) + << "Only call_tir and call_tir_inplace are supported in primitive function, but got: " + << GetRef<Expr>(call); + if (call->op == call_tir_inplace_op_) { + CollectVarMapping(call, lhs_var, /*in_place*/ true); + } else { + CollectVarMapping(call, lhs_var); + } + } + } + + private: + /*! \brief The IRModule */ + const IRModule& mod_; + // size_t call_num_inputs_ = -1; + Map<Var, tir::Buffer> relax_to_tir_var_map_; Review Comment: Yes, I did think about this issue, but I assumed that even though the same relax var might map to different buffers, it should have the same buffer attributes (since its source is the same relax var). I've also added a validation ICHECK to verify that the buffer attributes match (using structural equality). I've also added a test case to verify this use case as suggested in the comment below. 
########## src/relax/transform/fuse_tir.cc: ########## @@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator { namespace relax { +static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices, + int num_inputs) { + Array<Integer> ret; + int last_idx = num_inputs; + for (auto idx : inplace_indices) { + int i = idx.IntValue(); + if (i >= 0) { + ret.push_back(Integer(i)); + } else { + ret.push_back(Integer(last_idx)); + last_idx++; + } + } + + return ret; +} + +class RelaxToTIRVarMapCollector : public ExprVisitor { + void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) { Review Comment: Done -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org