Lunderberg commented on code in PR #17075:
URL: https://github.com/apache/tvm/pull/17075#discussion_r1633589835


##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+      if (i < num_inputs) {
+        const auto& relax_var = Downcast<Var>(relax_args[i]);

Review Comment:
   This `Downcast<Var>` is not guaranteed to work.  While the normalizer will 
pull most `relax.Var` instances out to their own variable binding, `R.const` 
arguments may still appear inline.
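   
   A sketch of one way to handle this (untested; assumes inline constants can 
simply be skipped, since a `relax::Constant` provides no `Var` to key the map 
on):
   
```c++
// Sketch: only record a mapping when the argument is actually a relax::Var.
// Inline constants (e.g. from R.const) have no Var to key the map on, so
// they are skipped here.
if (const auto* relax_var = relax_args[i].as<VarNode>()) {
  relax_to_tir_var_map_.Set(GetRef<Var>(relax_var), buffer_map[tir_var]);
}
```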



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+      if (i < num_inputs) {
+        const auto& relax_var = Downcast<Var>(relax_args[i]);
+        relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
+      }
+      if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i); it != output_idxs.end()) {
+        int result_idx = it - output_idxs.begin();
+        const auto& inplace_out_var = Downcast<Var>(relax_results[result_idx]);
+        relax_to_tir_var_map_.Set(inplace_out_var, buffer_map[tir_var]);
+      }
+    }
+  }
+
+ public:
+  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+  static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
+    RelaxToTIRVarMapCollector visitor(mod);
+    visitor(func->body);
+    return visitor.relax_to_tir_var_map_;
+  }
+  void VisitBinding_(const VarBindingNode* binding) final {
+    const auto& lhs_var = binding->var;

Review Comment:
   As written, this would only collect the mapping for relax variables whose 
binding occurs in the outermost `SeqExpr`.  Nested `SeqExpr` may occur if 
`binding->value` is a `relax::If` node, where each branch then contains a 
`SeqExpr`.
   
   To resolve this, I'd recommend either adding 
`ExprVisitor::VisitBinding_(binding);` or `VisitExpr(binding->value)` to this 
method.
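   
   A sketch of the first option (untested):
   
```c++
void VisitBinding_(const VarBindingNode* binding) final {
  // Recurse first, so that bindings inside nested SeqExpr (e.g. within the
  // branches of a relax::If bound here) are also collected.
  ExprVisitor::VisitBinding_(binding);
  if (const CallNode* call = binding->value.as<CallNode>()) {
    // ... existing call_tir / call_tir_inplace handling ...
  }
}
```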



##########
tests/python/relax/test_transform_fuse_tir.py:
##########
@@ -2314,5 +2314,88 @@ def take(
     _check(Before, Before)
 
 
+def test_fuse_with_axis_separators():
+    @I.ir_module
+    class Before:
+        @T.prim_func(private=True)
+        def add(a: T.handle, b: T.handle, c: T.handle):
+            A = T.match_buffer(a, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            B = T.match_buffer(b, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+            C = T.match_buffer(c, [T.int64(16), T.int64(32)], "float32", axis_separators=[1])
+
+            for iters in T.grid(T.int64(16), T.int64(32)):
+                with T.block("compute"):
+                    i, j = T.axis.remap("SS", iters)
+                    C[i, j] = A[i, j] + B[i, j]
+
+        @R.function(private=True)
+        def fused_function(
+            x: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            y: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+            z: R.Tensor([T.int64(16), T.int64(32)], "float32"),
+        ) -> R.Tensor([T.int64(16), T.int64(32)], dtype="float32"):
+            R.func_attr({"Primitive": 1})
+            cls = Before
+            with R.dataflow():
+                w = R.call_tir(
+                    cls.add, [x, y], out_sinfo=R.Tensor([T.int64(16), T.int64(32)], "float32")

Review Comment:
   Can we add a test case for incompatible usage of a single Relax var?  As 
currently written, we could have a single Relax variable that is used in two 
separate `R.call_tir` statements, where the function being called imposes 
different restrictions on it.  For example, if `x` were used in `cls.add1`, 
which requires `axis_separators=[1]`, and `cls.add2`, which requires 
`axis_separators=[]`.  We should be able to identify this case and raise an 
error when it occurs.
   
   (Ideally, that should never happen, but this would be the last point at 
which we'd have enough information to catch this failure mode at compile-time.)
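   
   For the detection itself, something like the following guard in the 
collector could work (sketch, untested; `relax_var` and `buffer` stand for the 
pair about to be inserted into the map):
   
```c++
// Sketch: when a relax var is already mapped, require the new buffer to be
// structurally compatible instead of silently overwriting the old entry.
if (auto prev = relax_to_tir_var_map_.Get(relax_var)) {
  CHECK(StructuralEqual()(prev.value()->axis_separators, buffer->axis_separators))
      << "Relax variable " << relax_var
      << " is used with incompatible axis_separators across call_tir calls";
} else {
  relax_to_tir_var_map_.Set(relax_var, buffer);
}
```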



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+      if (i < num_inputs) {
+        const auto& relax_var = Downcast<Var>(relax_args[i]);
+        relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);

Review Comment:
   The `buffer_map` does not necessarily contain an entry for `tir_var`.  For 
example, the `relax_var` could have `PrimStructInfo` to pass a primitive scalar 
to the TIR function.  Even if `relax_var` has `TensorStructInfo`, the TIR 
function may treat the `DLTensor*` as an opaque pointer, passing it to a 
`PackedFunc` without having an entry in the `buffer_map`.
   
   The best way to handle these cases is to wrap this line in an `if (auto 
tir_buffer = buffer_map.Get(tir_var))` conditional, and then use 
`tir_buffer.value()` inside the conditional instead of `buffer_map[tir_var]`.
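   
   i.e. (sketch):
   
```c++
// Sketch: skip TIR params with no buffer_map entry (scalar PrimStructInfo
// arguments, or DLTensor* handles used opaquely) instead of asserting.
if (auto tir_buffer = buffer_map.Get(tir_var)) {
  relax_to_tir_var_map_.Set(relax_var, tir_buffer.value());
}
```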



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      ret.push_back(Integer(last_idx));

Review Comment:
   Nitpick: This should only apply for `i==-1`.  For any other negative value, 
we should raise an error.
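   
   e.g. (sketch):
   
```c++
if (i >= 0) {
  ret.push_back(Integer(i));
} else if (i == -1) {
  // -1 marks an output that is freshly allocated rather than aliasing an input.
  ret.push_back(Integer(last_idx));
  last_idx++;
} else {
  LOG(FATAL) << "Invalid in-place index " << i
             << ": expected -1 or a non-negative input index";
}
```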



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = Downcast<tir::Var>(tir_args[i]);
+      if (i < num_inputs) {
+        const auto& relax_var = Downcast<Var>(relax_args[i]);
+        relax_to_tir_var_map_.Set(relax_var, buffer_map[tir_var]);
+      }
+      if (auto it = std::find(output_idxs.begin(), output_idxs.end(), i); it != output_idxs.end()) {
+        int result_idx = it - output_idxs.begin();
+        const auto& inplace_out_var = Downcast<Var>(relax_results[result_idx]);
+        relax_to_tir_var_map_.Set(inplace_out_var, buffer_map[tir_var]);
+      }
+    }
+  }
+
+ public:
+  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
+  static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function& func) {
+    RelaxToTIRVarMapCollector visitor(mod);
+    visitor(func->body);
+    return visitor.relax_to_tir_var_map_;
+  }
+  void VisitBinding_(const VarBindingNode* binding) final {
+    const auto& lhs_var = binding->var;
+    const auto& value = binding->value;
+    if (const CallNode* call = value.as<CallNode>()) {
+      static const Op& call_tir_op_ = Op::Get("relax.call_tir");
+      static const Op& call_tir_inplace_op_ = Op::Get("relax.call_tir_inplace");
+
+      ICHECK(call->op == call_tir_op_ || call->op == call_tir_inplace_op_)
+          << "Only call_tir and call_tir_inplace are supported in primitive function, but got: "
+          << GetRef<Expr>(call);
+      if (call->op == call_tir_inplace_op_) {
+        CollectVarMapping(call, lhs_var, /*in_place*/ true);
+      } else {
+        CollectVarMapping(call, lhs_var);
+      }
+    }
+  }
+
+ private:
+  /*! \brief The IRModule */
+  const IRModule& mod_;
+  // size_t call_num_inputs_ = -1;
+  Map<Var, tir::Buffer> relax_to_tir_var_map_;

Review Comment:
   This data structure assumes that there is a 1:1 mapping from `relax::Var` to 
`tir::Buffer` across the entire fused function.  This would produce incorrect 
results for cases where the same tensor is used as multiple arguments (e.g. 
`R.add(A, A)`), or where the same tensor is used as an argument to more than 
one function (e.g. the tensor `A` corresponds to two different TIR buffers in 
the sequence `mean = R.mean(A); norm = R.sqrt(mean); A_norm = R.divide(A, 
norm)`).
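   
   One way out (sketch, untested) is to collect every buffer a var is matched 
against and reconcile them afterwards, rather than keeping a single entry per 
var:
   
```c++
// Sketch: record all TIR buffers each relax var binds to across call sites.
std::unordered_map<Var, std::vector<tir::Buffer>, ObjectPtrHash, ObjectPtrEqual>
    buffers_per_var_;

void Record(const Var& relax_var, const tir::Buffer& buffer) {
  buffers_per_var_[relax_var].push_back(buffer);
}
// A later step would then verify that all buffers recorded for a var impose
// compatible constraints (axis_separators, storage scope) before choosing a
// representative, and raise an error otherwise.
```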



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) {

Review Comment:
   Nitpick: For readability, having the public-facing interface at the top of 
the class makes it easier to find the entry point.  Can the `static Map<Var, 
tir::Buffer> Collect` function be moved to the top of 
`RelaxToTIRVarMapCollector`?
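   
   i.e. something like:
   
```c++
class RelaxToTIRVarMapCollector : public ExprVisitor {
 public:
  static Map<Var, tir::Buffer> Collect(const IRModule& mod, const Function& func);

 private:
  explicit RelaxToTIRVarMapCollector(const IRModule& mod) : mod_(mod) {}
  void VisitBinding_(const VarBindingNode* binding) final;
  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false);

  const IRModule& mod_;
  Map<Var, tir::Buffer> relax_to_tir_var_map_;
};
```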



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -391,10 +484,15 @@ class FusedTIRConstructor : public ExprVisitor {
       : mod_(mod), func_name_(func_name) {}
 
   void VisitExpr_(const FunctionNode* func) final {
+    auto relax_to_tir_var_map = RelaxToTIRVarMapCollector::Collect(mod_, GetRef<Function>(func));
     std::vector<Variant<tir::Var, tir::Buffer>> prim_func_params;
     for (const Var& relax_param : func->params) {
       size_t size_before = prim_func_params.size();
-      CollectPrimFuncParams(relax_param, &prim_func_params);
+      if (relax_to_tir_var_map.count(relax_param)) {

Review Comment:
   Nitpick: Instead of having the conditional, both branches could be written 
as `CollectPrimFuncParams(relax_param, &prim_func_params, 
relax_to_tir_var_map.Get(relax_param))`.  The `Map::Get` method returns an 
`Optional<tir::Buffer>`, which is `NullOpt` when the key is absent from the map.
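   
   i.e. (sketch, assuming `CollectPrimFuncParams` gains an 
`Optional<tir::Buffer>` parameter):
   
```c++
// Sketch: NullOpt means no buffer constraint was collected for this parameter.
CollectPrimFuncParams(relax_param, &prim_func_params,
                      relax_to_tir_var_map.Get(relax_param));
```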



##########
src/relax/transform/fuse_tir.cc:
##########
@@ -362,6 +362,99 @@ class BlockNameDeduplicator : public tir::StmtMutator {
 
 namespace relax {
 
+static Array<Integer> GetInplaceOutputIndices(const Array<Integer>& inplace_indices,
+                                              int num_inputs) {
+  Array<Integer> ret;
+  int last_idx = num_inputs;
+  for (auto idx : inplace_indices) {
+    int i = idx.IntValue();
+    if (i >= 0) {
+      ret.push_back(Integer(i));
+    } else {
+      ret.push_back(Integer(last_idx));
+      last_idx++;
+    }
+  }
+
+  return ret;
+}
+
+class RelaxToTIRVarMapCollector : public ExprVisitor {
+  void CollectVarMapping(const CallNode* call, const Expr& lhs_var, bool in_place = false) {
+    GlobalVar gv = Downcast<GlobalVar>(call->args[0]);
+    tir::PrimFunc prim_func_ = Downcast<tir::PrimFunc>(mod_->Lookup(gv));
+    const auto& buffer_map = prim_func_->buffer_map;
+    const auto& tir_args = prim_func_->params;
+
+    const auto& relax_args = Downcast<Tuple>(call->args[1])->fields;
+
+    Array<Expr> relax_results;
+    if (lhs_var->IsInstance<TupleNode>()) {
+      relax_results = Downcast<Tuple>(lhs_var)->fields;
+    } else {
+      CHECK(lhs_var->IsInstance<VarNode>()) << "The lhs_var is expected to be either tuple or var";
+      relax_results = {Downcast<Var>(lhs_var)};
+    }
+
+    size_t num_inputs = relax_args.size();
+    size_t num_outputs = relax_results.size();
+
+    Array<Integer> output_idxs;
+    if (in_place) {
+      const auto* attrs = call->attrs.as<CallTIRInplaceAttrs>();
+      CHECK(attrs) << "Must have CallTIRInplaceAttrs for an in-place call";
+      output_idxs = GetInplaceOutputIndices(attrs->inplace_indices, num_inputs);
+    } else {
+      for (size_t i = num_inputs; i < num_inputs + num_outputs; i++) {
+        output_idxs.push_back(i);
+      }
+    }
+    for (size_t i = 0; i < tir_args.size(); ++i) {
+      const auto& tir_var = Downcast<tir::Var>(tir_args[i]);

Review Comment:
   This `Downcast<tir::Var>` is unnecessary, because `prim_func->params` is 
already an array of `tir::Var`.
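   
   i.e.:
   
```c++
// prim_func_->params is already Array<tir::Var>, so no Downcast is needed:
tir::Var tir_var = tir_args[i];
```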


