This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 4d19c8ab1f [Unity][Transform] Improved canonicalization of non-dataflow Var (#15941)
4d19c8ab1f is described below

commit 4d19c8ab1f2dd03b1cd6f7eff10eba020867e4b4
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Wed Oct 25 09:13:33 2023 -0500

    [Unity][Transform] Improved canonicalization of non-dataflow Var (#15941)
    
    * [Unity][Transform] Improved canonicalization of non-dataflow Var
    
    Prior to this commit, `relax.transform.CanonicalizeBindings` removed
    trivial bindings `var_y = var_x` where `var_y: relax.DataflowVar`
    and `var_x: relax.Var`, but did not remove trivial bindings where
    `var_y: relax.Var` and `var_x: relax.DataflowVar`.  This was to avoid
    invalid use of a `relax.DataflowVar` outside of a dataflow block.
    
    This commit updates `CanonicalizeBindings` to handle this type of
    binding as well.  To ensure that no `relax.DataflowVar` instances are
    used outside of a dataflow block, the replacement goes in the
    opposite direction: `var_y: relax.Var` is bound at the point where
    `var_x: relax.DataflowVar` was defined, and every use of `var_x` is
    rewritten to `var_y`, instead of replacing each use of `var_y` with
    `var_x`.
    
    This commit also canonicalizes `relax.Var` definitions to
    `relax.DataflowVar`, if the binding occurs within a dataflow block,
    and the variable is never used outside of a dataflow block.
    
    * Simplify unwrapping of known bindings
    
    * Updated to use Map<Id,Var>, to avoid while(true) loops
---
 src/relax/transform/canonicalize_bindings.cc       | 247 +++++++++++++++------
 tests/python/relax/test_dataflow_pattern.py        |   3 +-
 .../python/relax/test_optimize_layout_transform.py |   9 +-
 .../python/relax/test_remove_redundant_reshape.py  |   3 +-
 .../relax/test_transform_canonicalize_bindings.py  | 244 +++++++++++++++++++-
 5 files changed, 414 insertions(+), 92 deletions(-)

diff --git a/src/relax/transform/canonicalize_bindings.cc 
b/src/relax/transform/canonicalize_bindings.cc
index 2e7f4311f9..246b38f6f8 100644
--- a/src/relax/transform/canonicalize_bindings.cc
+++ b/src/relax/transform/canonicalize_bindings.cc
@@ -33,68 +33,187 @@
 namespace tvm {
 namespace relax {
 
-class BindingCanonicalizer : public ExprMutator {
+namespace {
+
+struct CanonicalizationPlan {
+  Map<Id, Var> replace_usage;
+  Map<Id, Var> replace_binding;
+  std::unordered_set<Id, ObjectPtrHash, ObjectPtrEqual> bindings_to_remove;
+};
+
+/*! \brief Utility class to plan the canonicalization of bindings
+ *
+ * Canonicalization of a variable binding may require information from
+ * later in the function.  For example, `dataflow_x = expr` may be
+ * replaced with `var_x = expr` to avoid a later trivial binding of
+ * `var_x = dataflow_x`.  This utility examines a relax expression
+ * and plans the changes to be made in a later mutation pass.
+ */
+class CanonicalizePlanner : public ExprVisitor {
  public:
-  BindingCanonicalizer() {}
-
-  using ExprMutator::VisitExpr_;
-
-  Expr VisitExpr_(const TupleGetItemNode* tuple_get_item) override {
-    if (auto tuple_var = tuple_get_item->tuple.as<Var>()) {
-      if (auto tuple_value = LookupBinding(tuple_var.value())) {
-        if (auto explicit_tuple = tuple_value.as<TupleNode>()) {
-          CHECK_GE(tuple_get_item->index, 0)
-              << "Tuple " << tuple_value << " is accessed at index " << 
tuple_get_item->index
-              << ", but negative indices are not supported in this context.";
-          CHECK_LT(tuple_get_item->index, explicit_tuple->fields.size())
-              << "Tuple " << tuple_value << " is accessed at index " << 
tuple_get_item->index
-              << ", but the tuple size is only " << 
explicit_tuple->fields.size();
-          return VisitExpr(explicit_tuple->fields[tuple_get_item->index]);
+  static CanonicalizationPlan Collect(const Expr& expr) {
+    CanonicalizePlanner visitor;
+    visitor.VisitExpr(expr);
+
+    CanonicalizationPlan plan;
+
+    std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> handled;
+
+    for (const auto& binding_iter : visitor.trivial_bindings_) {
+      Var bound_var = binding_iter.first;
+      Var bound_to = binding_iter.second;
+
+      while (auto opt = visitor.trivial_bindings_.Get(bound_to)) {
+        // This may be a trivial binding into a trivial binding.  In
+        // that case, unwrap the bindings until we find the earliest
+        // non-trivial binding.
+        bound_to = opt.value();
+      }
+
+      while (auto opt = plan.replace_binding.Get(bound_to->vid)) {
+        // The variable we are binding to may have already been
+        // replaced, if it fell into Case 4 (Var = DataflowVar).  In
+        // that case, we check against its replacement instead.
+        bound_to = opt.value();
+      }
+
+      if (bound_var.as<DataflowVarNode>() || !bound_to.as<DataflowVarNode>()) {
+        // Case 1: Var = Var
+        // Case 2: DataflowVar = Var
+        // Case 3: DataflowVar = DataflowVar
+        //
+        // For these three cases, the trivial binding can be
+        // unwrapped, using the bound variable directly at the point
+        // of use.
+        plan.replace_usage.Set(bound_var->vid, bound_to);
+        plan.bindings_to_remove.insert(bound_var->vid);
+        handled.insert(bound_to);
+      } else {
+        // Case 4: Var = DataflowVar
+        //
+        // Replacing a Var with a DataflowVar could result in illegal
+        // use of a DataflowVar outside of a DataflowBlock.  Instead,
+        // we replace in the opposite direction, replacing the binding
+        // of the DataflowVar with a binding of the Var.
+        plan.replace_binding.Set(bound_to->vid, bound_var);
+        plan.replace_usage.Set(bound_to->vid, bound_var);
+        plan.bindings_to_remove.insert(bound_var->vid);
+        handled.insert(bound_var);
+      }
+    }
+
+    // If a Var has been defined inside a DataflowBlock, is only used
+    // within a DataflowBlock, and is not already handled by removal
+    // of trivial bindings, then we can replace it with a DataflowVar.
+    for (const auto& var : visitor.defined_inside_dataflow_) {
+      if (!var.as<DataflowVarNode>() && 
!visitor.used_outside_dataflow_.count(var) &&
+          !handled.count(var)) {
+        DataflowVar new_var(var->name_hint(), GetStructInfo(var));
+        plan.replace_binding.Set(var->vid, new_var);
+        plan.replace_usage.Set(var->vid, new_var);
+      }
+    }
+
+    return plan;
+  }
+
+ private:
+  void VisitBindingBlock_(const DataflowBlockNode* block) override {
+    bool cache = inside_dataflow_;
+    inside_dataflow_ = true;
+    ExprVisitor::VisitBindingBlock_(block);
+    inside_dataflow_ = cache;
+  }
+
+  void VisitBinding(const Binding& binding) override {
+    bool has_same_struct_info = true;
+    Expr value;
+    if (auto ptr = binding.as<VarBindingNode>()) {
+      value = ptr->value;
+    } else if (auto ptr = binding.as<MatchCastNode>()) {
+      has_same_struct_info =
+          StructuralEqual()(GetStructInfo(binding->var), 
GetStructInfo(ptr->value));
+      value = ptr->value;
+    } else {
+      LOG(FATAL) << "Invalid binding type: " << binding->GetTypeKey();
+    }
+
+    // Unwrap TupleGetItem, if the Tuple being accessed is known.
+    if (auto tuple_get_item = value.as<TupleGetItemNode>()) {
+      Expr tuple = tuple_get_item->tuple;
+      while (auto tuple_var = tuple.as<Var>()) {
+        if (auto opt = known_bindings_.Get(tuple_var.value())) {
+          tuple = opt.value();
+        } else {
+          break;
         }
       }
+
+      if (auto ptr = tuple.as<TupleNode>()) {
+        value = ptr->fields[tuple_get_item->index];
+      }
+    }
+
+    if (auto parent = value.as<Var>(); parent && has_same_struct_info) {
+      trivial_bindings_.Set(binding->var, parent.value());
     }
-    return ExprMutator::VisitExpr_(tuple_get_item);
+
+    known_bindings_.Set(binding->var, value);
+
+    ExprVisitor::VisitBinding(binding);
   }
 
-  void VisitBinding_(const VarBindingNode* binding) override {
-    // Unlike default visitor, we do not permit the struct info to change
-    // if the new value's struct info is different (this preserves user 
annotations)
-    Expr new_value = this->VisitExpr(binding->value);
-    Var new_var = this->VisitVarDef(binding->var);
-
-    if (auto opt_var = new_value.as<Var>();
-        opt_var && CanCanonicalizeVar(new_var, opt_var.value())) {
-      var_remap_[new_var->vid] = opt_var.value();
-    } else if (new_var.same_as(binding->var) && 
new_value.same_as(binding->value)) {
-      this->builder_->EmitNormalized(GetRef<VarBinding>(binding));
+  void VisitVarDef(const Var& var) override {
+    if (inside_dataflow_) {
+      defined_inside_dataflow_.insert(var);
+    }
+  }
+
+  void VisitExpr_(const VarNode* var) override {
+    if (!inside_dataflow_) {
+      used_outside_dataflow_.insert(GetRef<Var>(var));
+    }
+  }
+
+  bool inside_dataflow_{false};
+
+  Map<Var, Var> trivial_bindings_;
+  Map<Var, Expr> known_bindings_;
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> 
defined_inside_dataflow_;
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> 
used_outside_dataflow_;
+};
+
+/*! \brief The mutator class to apply a CanonicalizationPlan */
+class BindingCanonicalizer : public ExprMutator {
+ public:
+  static Expr Apply(Expr expr) {
+    auto used_outside_dataflow = CanonicalizePlanner::Collect(expr);
+    BindingCanonicalizer mutator(std::move(used_outside_dataflow));
+    return mutator.VisitExpr(expr);
+  }
+
+ private:
+  explicit BindingCanonicalizer(CanonicalizationPlan plan) : plan_(plan) {}
+
+  void VisitBinding(const Binding& binding) override {
+    if (!plan_.bindings_to_remove.count(binding->var->vid)) {
+      ExprMutator::VisitBinding(binding);
+    }
+  }
+
+  Var VisitVarDef(const Var& var) override {
+    if (auto opt = plan_.replace_binding.Get(var->vid)) {
+      return ExprMutator::VisitVarDef(opt.value());
     } else {
-      this->builder_->EmitNormalized(VarBinding(new_var, new_value));
+      return ExprMutator::VisitVarDef(var);
     }
   }
 
-  void VisitBinding_(const MatchCastNode* binding) override {
-    // If we have a trivial shape check (the struct_info_ of LHS and RHS is 
the same),
-    // we can canonicalize to a var binding
-    Expr new_value = this->VisitExpr(binding->value);
-    bool has_same_struct_info = StructuralEqual()(binding->struct_info, 
GetStructInfo(new_value));
-
-    if (has_same_struct_info) {
-      if (auto parent = new_value.as<Var>();
-          parent && CanCanonicalizeVar(binding->var, parent.value())) {
-        // LHS and RHS have the same struct info, and occur in a
-        // context where the RHS can replace the LHS.
-        var_remap_[binding->var->vid] = parent.value();
-      } else {
-        // LHS and RHS have the same struct info, but the RHS is not a
-        // drop-in replacement for the LHS.
-        builder_->EmitNormalized(VarBinding(binding->var, new_value));
-      }
-    } else if (new_value.same_as(binding->value)) {
-      builder_->EmitNormalized(GetRef<MatchCast>(binding));
+  Expr VisitExpr_(const VarNode* var) override {
+    if (auto opt = plan_.replace_usage.Get(var->vid)) {
+      return ExprMutator::VisitExpr(opt.value());
     } else {
-      // we can't elide in the same way as with var bindings because
-      // the struct info comparison has semantics
-      builder_->EmitNormalized(MatchCast(binding->var, new_value, 
binding->struct_info));
+      return ExprMutator::VisitExpr_(var);
     }
   }
 
@@ -200,31 +319,11 @@ class BindingCanonicalizer : public ExprMutator {
   }
 
  private:
-  bool AnnotationsDiffer(const ObjectRef& obj1, const ObjectRef& obj2,
-                         std::function<bool(const ObjectRef&, const 
ObjectRef&)> check_eq) {
-    // annotations differ if one is present but not the other
-    // or they're both present and they differ
-    bool both_present = obj1.defined() && obj2.defined();
-    bool neither_present = !obj1.defined() && !obj2.defined();
-    return !(both_present || neither_present) || (both_present && 
!check_eq(obj1, obj2));
-  }
-
-  bool CanCanonicalizeVar(Var var, Var parent_var) {
-    // Cases when we conservatively do not unify:
-    // 1. The struct_info_ of the child differs from that of the parent
-    //    In this case, we could be overriding user annotations.
-    // 2. If the child is a Var and the parent is a DataflowVar.
-    //    That could result in a DataflowVar leaving the current DataflowBlock.
-    bool annotations_differ = AnnotationsDiffer(var->struct_info_, 
parent_var->struct_info_,
-                                                [&](const ObjectRef& lhs, 
const ObjectRef& rhs) {
-                                                  return 
tvm::StructuralEqual()(lhs, rhs);
-                                                });
-    bool var_to_dataflow = (!var.as<DataflowVarNode>() && 
parent_var.as<DataflowVarNode>());
-    return !annotations_differ && !var_to_dataflow;
-  }
+  CanonicalizationPlan plan_;
 };
+}  // namespace
 
-Expr CanonicalizeBindings(const Expr& e) { return 
BindingCanonicalizer().VisitExpr(e); }
+Expr CanonicalizeBindings(const Expr& expr) { return 
BindingCanonicalizer::Apply(expr); }
 
 namespace transform {
 
diff --git a/tests/python/relax/test_dataflow_pattern.py 
b/tests/python/relax/test_dataflow_pattern.py
index 49b7d11a80..a8b71aa5eb 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1403,8 +1403,7 @@ def 
test_rewrite_without_trivial_binding(bind_to_dataflow_var):
         @R.function(private=True)
         def expected(x: R.Tensor((1024,))):
             with R.dataflow():
-                a = R.add(x, x)
-                b = a
+                b = R.add(x, x)
                 R.output(b)
             return b
 
diff --git a/tests/python/relax/test_optimize_layout_transform.py 
b/tests/python/relax/test_optimize_layout_transform.py
index 08c9e31107..3addfab2e8 100644
--- a/tests/python/relax/test_optimize_layout_transform.py
+++ b/tests/python/relax/test_optimize_layout_transform.py
@@ -130,10 +130,9 @@ def test_optimize_transform_layout_pass_one_arg():
                     (lv1, lv2),
                     out_sinfo=R.Tensor((4, 4), dtype="float32"),
                 )
-                lv2_1: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
                     lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), 
pad_value=None
                 )
-                gv: R.Tensor((16,), dtype="float32") = lv2_1
                 R.output(gv)
             return gv
 
@@ -256,10 +255,9 @@ def test_optimize_transform_layout_pass_two_args():
                     (lv3, lv4),
                     out_sinfo=R.Tensor((4, 4), dtype="float32"),
                 )
-                lv6: R.Tensor((16,), dtype="float32") = R.layout_transform(
+                gv: R.Tensor((16,), dtype="float32") = R.layout_transform(
                     lv5, index_map=lambda axis0, axis1: (axis0 * 4 + axis1,), 
pad_value=None
                 )
-                gv: R.Tensor((16,), dtype="float32") = lv6
                 R.output(gv)
             return gv
 
@@ -399,10 +397,9 @@ def test_tranform_layout_tir_remove_pad_transform_layout():
                     pad_value=None,
                     axis_separators=[],
                 )
-                lv_2 = R.call_tir(
+                gv = R.call_tir(
                     Expected.remove_pad, (lv5,), out_sinfo=R.Tensor((14,), 
dtype="float32")
                 )
-                gv: R.Tensor((14,), dtype="float32") = lv_2
                 R.output(gv)
             return gv
 
diff --git a/tests/python/relax/test_remove_redundant_reshape.py 
b/tests/python/relax/test_remove_redundant_reshape.py
index 11e8c87cf1..a28141616c 100644
--- a/tests/python/relax/test_remove_redundant_reshape.py
+++ b/tests/python/relax/test_remove_redundant_reshape.py
@@ -52,8 +52,7 @@ def test_remove_redundant_reshape_pass_one_arg():
             x: R.Tensor((1, 1001, 1, 1), dtype="float16")
         ) -> R.Tensor((1, 1001), dtype="float16"):
             with R.dataflow():
-                lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, 
R.shape([1, 1001]))
-                gv: R.Tensor((1, 1001), dtype="float16") = lv
+                gv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, 
R.shape([1, 1001]))
                 R.output(gv)
             return gv
 
diff --git a/tests/python/relax/test_transform_canonicalize_bindings.py 
b/tests/python/relax/test_transform_canonicalize_bindings.py
index aed5dad557..92057ce46a 100644
--- a/tests/python/relax/test_transform_canonicalize_bindings.py
+++ b/tests/python/relax/test_transform_canonicalize_bindings.py
@@ -295,10 +295,9 @@ def test_unable_to_fold():
         @R.function
         def main() -> R.Tensor((), "int32"):
             with R.dataflow():
-                y = R.const(1)
+                n = R.const(1)
                 # multiple uses -> cannot coalesce
-                m = R.add(y, y)
-                n = y
+                m = R.add(n, n)
                 R.output(n)
             return n
 
@@ -353,15 +352,244 @@ def test_multiply_used_in_outputs():
         @R.function
         def main() -> R.Tensor((), "int32"):
             with R.dataflow():
-                x = R.const(1)
-                l = x
-                m = x
-                n = x
-                R.output(l, m, n)
+                n = R.const(1)
+                R.output(n)
             return n
 
     verify(UsedInMultipleOutputs, UsedInMultipleOutputs)
 
 
+def test_canonicalize_var_to_dataflow_var_if_legal():
+    """Canonicalize Var to DataflowVar inside DataflowBlock
+
+    DataflowVar instances may only be used inside a DataflowBlock.  If
+    a trivial binding `y = x` occurs, where `x` is a `DataflowVar` and
+    `y` is a `Var`, replacing `y` with `x` may result in usage of a
+    `DataflowVar` outside of a `DataflowBlock`.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.add(x, R.const(1))
+                z = R.add(y, R.const(1))
+                R.output(y, z)
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.add(x, R.const(1))
+                z = R.add(y, R.const(1))
+                R.output(z)
+            return z
+
+    after = relax.transform.CanonicalizeBindings()(Before)
+    assert_structural_equal(Expected, after)
+
+
+def test_update_dataflow_computations_if_var_replacement_occurs():
+    """Canonicalize Var to DataflowVar inside DataflowBlock
+
+    DataflowBlocks may produce additional outputs after the first
+    output Var, and these additional outputs may be in terms of the
+    first output.  Computations that depend on a replaced var must be
+    updated to remain well-formed.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                lv1 = R.add(x, R.const(1))
+                gv1 = lv1
+                gv2 = R.add(lv1, R.const(1))
+                R.output(gv1, gv2)
+            return (gv1, gv2)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                # lv1 has been replaced with gv1
+                gv1 = R.add(x, R.const(1))
+                # So gv1 must be used in the computation of gv2
+                gv2 = R.add(gv1, R.const(1))
+                R.output(gv1, gv2)
+            return (gv1, gv2)
+
+    after = relax.transform.CanonicalizeBindings()(Before)
+    assert_structural_equal(Expected, after)
+
+
+def test_update_dataflow_computations_if_var_replacement_occurs_after_usage():
+    """Canonicalize Var to DataflowVar inside DataflowBlock
+
+    Like test_update_dataflow_computations_if_var_replacement_occurs,
+    but the usage of a DataflowVar occurs before the trivial binding
+    that causes it to be replaced.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                lv1 = R.add(x, R.const(1))
+                gv2 = R.add(lv1, R.const(1))
+                gv1 = lv1
+                R.output(gv1, gv2)
+            return (gv1, gv2)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                # lv1 has been replaced with gv1
+                gv1 = R.add(x, R.const(1))
+                # So gv1 must be used in the computation of gv2
+                gv2 = R.add(gv1, R.const(1))
+                # Even though the trivial binding of "gv1 = lv1"
+                # occurred in this position.
+                R.output(gv1, gv2)
+            return (gv1, gv2)
+
+    after = relax.transform.CanonicalizeBindings()(Before)
+    assert_structural_equal(Expected, after)
+
+
+def test_canonicalize_trivial_binding_to_dataflow_var():
+    """Canonicalize Var to DataflowVar inside DataflowBlock
+
+    DataflowVar instances may only be used inside a DataflowBlock.  If
+    a trivial binding `y = x` occurs, where `x` is a `DataflowVar` and
+    `y` is a `Var`, replacing `y` with `x` may result in usage of a
+    `DataflowVar` outside of a `DataflowBlock`.
+
+    If a binding exists solely to convert from DataflowVar into Var,
+    then canonicalization replaces the earlier DataflowVar with a Var.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.add(x, R.const(1))
+                z = y
+                R.output(z)
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.add(x, R.const(1))
+                R.output(y)
+            return y
+
+    after = relax.transform.CanonicalizeBindings()(Before)
+    assert_structural_equal(Expected, after)
+
+
+def test_canonicalize_multiple_trivial_binding_to_dataflow_var():
+    """Canonicalize Var to DataflowVar inside DataflowBlock
+
+    Like test_canonicalize_trivial_binding_to_dataflow_var, but there
+    exist multiple trivial bindings to the DataflowVar.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(w: R.Tensor):
+            with R.dataflow():
+                x = R.add(w, R.const(1))
+                y = x
+                z = x
+                R.output(y, z)
+            return (y, z)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(w: R.Tensor):
+            with R.dataflow():
+                x = R.add(w, R.const(1))
+                R.output(x)
+            return (x, x)
+
+    after = relax.transform.CanonicalizeBindings()(Before)
+    assert_structural_equal(Expected, after)
+
+
+def test_canonicalize_trivial_var_binding_inside_dataflow_block():
+    """Canonicalize Var to DataflowVar inside DataflowBlock
+
+    Canonicalization handles cases where a Var could be replaced by a
+    DataflowVar, and where a Var is a trivial binding.  If these two
+    cases both occur, should produce reasonable results.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.add(x, R.const(1))
+                z = y
+                R.output(y, z)
+            return z
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.add(x, R.const(1))
+                R.output(y)
+            return y
+
+    after = relax.transform.CanonicalizeBindings()(Before)
+    assert_structural_equal(Expected, after)
+
+
+def test_canonicalize_across_non_dataflow_tuple():
+    """Canonicalize Var to DataflowVar inside DataflowBlock"""
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.add(x, R.const(1))
+                z = (y,)
+                gv = R.add(z[0], R.const(1))
+                R.output(z, gv)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor):
+            with R.dataflow():
+                y = R.add(x, R.const(1))
+                z = (y,)
+                gv = R.add(y, R.const(1))
+                R.output(gv)
+            return gv
+
+    after = relax.transform.CanonicalizeBindings()(Before)
+    assert_structural_equal(Expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to