This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bf2d43e314 [IR][Relax] Improve highlighting in assert_structural_equal 
(#16756)
bf2d43e314 is described below

commit bf2d43e314ca7e682ae26dca70ada657054f8786
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Tue Mar 26 08:26:52 2024 -0500

    [IR][Relax] Improve highlighting in assert_structural_equal (#16756)
    
    * [IR][Relax] Improve highlighting in assert_structural_equal
    
    Prior to this commit, `tvm.ir.assert_structural_equal` would highlight
    an entire `relax::BindingBlock` if the number of elements in the
    binding block differs.  This can result in the entire Relax function
    being highlighted, making it difficult to identify the location of the
    mismatch.
    
    This commit makes the following changes to improve the error messages
    produced when `tvm.ir.assert_structural_equal` raises an exception.
    
    - In `"node.StructuralEqual"`, set `defer_fails = true` when
      `assert_mode` is true.  This highlights the first mismatch of an
      `Array<relax::Binding>`, rather than the entire array, in cases
      where the LHS and RHS have different sizes.
    
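    - In the `SEqualReduce` and `SHashReduce` for `VarBinding` and
      `MatchCast`, visit the value first, and then the variable to which
      it is bound.  This highlights the mismatched expression, rather than
      mismatches in the resulting struct info.  (When the bound value is a
      function, the variable is still visited first, since a recursive
      function may refer to its own bound variable.)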
    
    - In `SEqualHandlerDefault::Impl::SEqualReduce`, defer the failure if
      enabled.  This highlights the first mismatch, which may also have been
      deferred, rather than returning early when a later mismatch involving
      `NullOpt` occurs.
    
    * DeferFail should follow assert_mode
    
    * Handle recursively defined lambda functions
---
 include/tvm/relax/expr.h         | 24 ++++-----------
 src/node/structural_equal.cc     | 45 +++++++++++++++++++---------
 src/relax/ir/expr.cc             | 50 +++++++++++++++++++++++++++++++
 tests/python/relax/test_utils.py | 63 +++++++++++++++++++++++++++++++++++++++-
 4 files changed, 149 insertions(+), 33 deletions(-)

diff --git a/include/tvm/relax/expr.h b/include/tvm/relax/expr.h
index 4634d1e228..40707675fe 100644
--- a/include/tvm/relax/expr.h
+++ b/include/tvm/relax/expr.h
@@ -780,18 +780,8 @@ class MatchCastNode : public BindingNode {
     v->Visit("span", &span);
   }
 
-  bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const {
-    // NOTE: pattern can contain ShapeExpr which defines the vars
-    return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, 
other->struct_info) &&
-           equal(value, other->value);
-  }
-
-  void SHashReduce(SHashReducer hash_reduce) const {
-    // NOTE: pattern can contain ShapeExpr which defines the vars
-    hash_reduce.DefHash(var);
-    hash_reduce.DefHash(struct_info);
-    hash_reduce(value);
-  }
+  bool SEqualReduce(const MatchCastNode* other, SEqualReducer equal) const;
+  void SHashReduce(SHashReducer hash_reduce) const;
 
   static constexpr const char* _type_key = "relax.expr.MatchCast";
   static constexpr const bool _type_has_method_sequal_reduce = true;
@@ -822,13 +812,9 @@ class VarBindingNode : public BindingNode {
     v->Visit("span", &span);
   }
 
-  bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const {
-    return equal.DefEqual(var, other->var) && equal(value, other->value);
-  }
-  void SHashReduce(SHashReducer hash_reduce) const {
-    hash_reduce.DefHash(var);
-    hash_reduce(value);
-  }
+  bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const;
+  void SHashReduce(SHashReducer hash_reduce) const;
+
   static constexpr const char* _type_key = "relax.expr.VarBinding";
   static constexpr const bool _type_has_method_sequal_reduce = true;
   static constexpr const bool _type_has_method_shash_reduce = true;
diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc
index 66a347f6b8..e0de514122 100644
--- a/src/node/structural_equal.cc
+++ b/src/node/structural_equal.cc
@@ -27,6 +27,7 @@
 #include <tvm/node/structural_equal.h>
 #include <tvm/runtime/registry.h>
 
+#include <optional>
 #include <unordered_map>
 
 #include "ndarray_hash_equal.h"
@@ -249,15 +250,30 @@ class SEqualHandlerDefault::Impl {
     // in which case we can use same_as for quick checking,
     // or we have to run deep comparison and avoid to use same_as checks.
     auto run = [=]() {
-      if (!lhs.defined() && !rhs.defined()) return true;
-      if (!lhs.defined() && rhs.defined()) return false;
-      if (!rhs.defined() && lhs.defined()) return false;
-      if (lhs->type_index() != rhs->type_index()) return false;
-      auto it = equal_map_lhs_.find(lhs);
-      if (it != equal_map_lhs_.end()) {
-        return it->second.same_as(rhs);
+      std::optional<bool> early_result = [&]() -> std::optional<bool> {
+        if (!lhs.defined() && !rhs.defined()) return true;
+        if (!lhs.defined() && rhs.defined()) return false;
+        if (!rhs.defined() && lhs.defined()) return false;
+        if (lhs->type_index() != rhs->type_index()) return false;
+        auto it = equal_map_lhs_.find(lhs);
+        if (it != equal_map_lhs_.end()) {
+          return it->second.same_as(rhs);
+        }
+        if (equal_map_rhs_.count(rhs)) return false;
+
+        return std::nullopt;
+      }();
+
+      if (early_result.has_value()) {
+        if (early_result.value()) {
+          return true;
+        } else if (IsPathTracingEnabled() && IsFailDeferralEnabled() && 
current_paths.defined()) {
+          DeferFail(current_paths.value());
+          return true;
+        } else {
+          return false;
+        }
       }
-      if (equal_map_rhs_.count(rhs)) return false;
 
       // need to push to pending tasks in this case
       pending_tasks_.emplace_back(lhs, rhs, map_free_vars, current_paths);
@@ -388,10 +404,7 @@ class SEqualHandlerDefault::Impl {
       auto& entry = task_stack_.back();
 
       if (entry.force_fail) {
-        if (IsPathTracingEnabled() && !first_mismatch_->defined()) {
-          *first_mismatch_ = entry.current_paths;
-        }
-        return false;
+        return CheckResult(false, entry.lhs, entry.rhs, entry.current_paths);
       }
 
       if (entry.children_expanded) {
@@ -530,8 +543,14 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const 
ObjectRef& lhs, const Obje
 TVM_REGISTER_GLOBAL("node.StructuralEqual")
     .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool 
assert_mode,
                        bool map_free_vars) {
+      // If we are asserting on failure, then the `defer_fails` option
+      // should be enabled, to provide better error messages.  For
+      // example, if the number of bindings in a `relax::BindingBlock`
+      // differs, highlighting the first difference rather than the
+      // entire block.
+      bool defer_fails = assert_mode;
       Optional<ObjectPathPair> first_mismatch;
-      return SEqualHandlerDefault(assert_mode, &first_mismatch, false)
+      return SEqualHandlerDefault(assert_mode, &first_mismatch, defer_fails)
           .Equal(lhs, rhs, map_free_vars);
     });
 
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 1bc7267af6..b709039e8c 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -384,6 +384,33 @@ TVM_REGISTER_GLOBAL("relax.MatchCast")
       return MatchCast(var, value, struct_info, span);
     });
 
+bool MatchCastNode::SEqualReduce(const MatchCastNode* other, SEqualReducer 
equal) const {
+  if (value->IsInstance<FunctionNode>()) {
+    // Recursive function definitions may reference the bound variable
+    // within the value being bound.  In these cases, the
+    // `DefEqual(var, other->var)` must occur first, to ensure it is
+    // defined at point of use.
+    return equal.DefEqual(var, other->var) && equal.DefEqual(struct_info, 
other->struct_info) &&
+           equal(value, other->value);
+  } else {
+    // In all other cases, visit the bound value before the variable
+    // it is bound to, in order to provide better error messages.
+    return equal(value, other->value) && equal.DefEqual(struct_info, 
other->struct_info) &&
+           equal.DefEqual(var, other->var);
+  }
+}
+void MatchCastNode::SHashReduce(SHashReducer hash_reduce) const {
+  if (value->IsInstance<FunctionNode>()) {
+    hash_reduce.DefHash(var);
+    hash_reduce.DefHash(struct_info);
+    hash_reduce(value);
+  } else {
+    hash_reduce(value);
+    hash_reduce.DefHash(struct_info);
+    hash_reduce.DefHash(var);
+  }
+}
+
 TVM_REGISTER_NODE_TYPE(VarBindingNode);
 
 VarBinding::VarBinding(Var var, Expr value, Span span) {
@@ -398,6 +425,29 @@ 
TVM_REGISTER_GLOBAL("relax.VarBinding").set_body_typed([](Var var, Expr value, S
   return VarBinding(var, value, span);
 });
 
+bool VarBindingNode::SEqualReduce(const VarBindingNode* other, SEqualReducer 
equal) const {
+  if (value->IsInstance<FunctionNode>()) {
+    // Recursive function definitions may reference the bound variable
+    // within the value being bound.  In these cases, the
+    // `DefEqual(var, other->var)` must occur first, to ensure it is
+    // defined at point of use.
+    return equal.DefEqual(var, other->var) && equal(value, other->value);
+  } else {
+    // In all other cases, visit the bound value before the variable
+    // it is bound to, in order to provide better error messages.
+    return equal(value, other->value) && equal.DefEqual(var, other->var);
+  }
+}
+void VarBindingNode::SHashReduce(SHashReducer hash_reduce) const {
+  if (value->IsInstance<FunctionNode>()) {
+    hash_reduce.DefHash(var);
+    hash_reduce(value);
+  } else {
+    hash_reduce(value);
+    hash_reduce.DefHash(var);
+  }
+}
+
 TVM_REGISTER_NODE_TYPE(BindingBlockNode);
 
 BindingBlock::BindingBlock(Array<Binding> bindings, Span span) {
diff --git a/tests/python/relax/test_utils.py b/tests/python/relax/test_utils.py
index 0cae5101a7..9abc53484b 100644
--- a/tests/python/relax/test_utils.py
+++ b/tests/python/relax/test_utils.py
@@ -14,12 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+
+import re
+
 import pytest
 
 import tvm
 from tvm import relax
 from tvm.ir.base import assert_structural_equal
-from tvm.script.parser import relax as R
+from tvm.script.parser import relax as R, tir as T
 
 
 def test_copy_with_new_vars():
@@ -122,6 +125,27 @@ def test_copy_with_new_vars_on_ir_module_nested_function():
     assert_structural_equal(Actual, Expected)
 
 
+def test_assert_structural_equal_in_seqexpr():
+    """The first mismatch is correctly identified."""
+
+    @R.function(private=True)
+    def func_1(A: R.Tensor([16, 16], "float32")):
+        B = R.concat([A, A])
+        return B
+
+    @R.function(private=True)
+    def func_2(A: R.Tensor([16, 16], "float32")):
+        B = R.add(A, A)
+        C = R.add(B, B)
+        return B
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("<root>.body.blocks[0].bindings[0].value.op"),
+    ):
+        assert_structural_equal(func_1, func_2)
+
+
 def test_structural_equal_of_call_nodes():
     """relax.Call must be compared by structural equality, not reference"""
 
@@ -145,5 +169,42 @@ def test_structural_equal_of_call_nodes():
     tvm.ir.assert_structural_equal(uses_same_object_twice, 
uses_two_different_objects)
 
 
+def test_structural_equal_with_recursive_lambda_function():
+    """A recursive lambda function may be checked for structural equality
+
+    Recursive function definitions may reference the bound variable
+    within the value being bound.  In these cases, the `DefEqual(var,
+    other->var)` must occur first, to ensure it is defined at point of
+    use.
+
+    In all other cases, checking for structural equality of the bound
+    value prior to the variable provides a better error message.
+    """
+
+    def define_function():
+        @R.function
+        def func(n: R.Prim("int64")):
+            @R.function
+            def recursive_lambda(i_arg: R.Prim(value="i")) -> R.Prim("int64"):
+                i = T.int64()
+                if R.prim_value(i == 0):
+                    output = R.prim_value(T.int64(0))
+                else:
+                    remainder_relax = recursive_lambda(R.prim_value(i - 1))
+                    remainder_tir = T.int64()
+                    _ = R.match_cast(remainder_relax, 
R.Prim(value=remainder_tir))
+                    output = R.prim_value(i + remainder_tir)
+                return output
+
+            return recursive_lambda(n)
+
+        return func
+
+    func_1 = define_function()
+    func_2 = define_function()
+
+    tvm.ir.assert_structural_equal(func_1, func_2)
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to