This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new fe74b37  Conditions updated to cover better user scenarios (#4951)
fe74b37 is described below

commit fe74b37ab578e6d3c540b0f6ac187a220ccc028a
Author: Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Wed Mar 4 18:35:38 2020 -0600

    Conditions updated to cover better user scenarios (#4951)
    
    * Conditions updated to cover better user scenarios
    
    * [1] New test case added
    
    * [2] New test case added
    
    * [3] Proper variable name used
    
    * [4] Review Comments handled
    
    * [5] Review comments handled
    
    * [6] Review comments handled
---
 src/relay/ir/alpha_equal.cc                 | 10 ++---
 tests/cpp/relay_pass_alpha_equal.cc         | 67 +++++++++++++++++++++++++++++
 tests/python/relay/test_pass_alpha_equal.py | 32 ++++++++++++++
 3 files changed, 104 insertions(+), 5 deletions(-)

diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index 78688d7..c622599 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -50,14 +50,14 @@ class AlphaEqualHandler:
    * \return The comparison result.
    */
   bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
-    if (lhs.same_as(rhs)) return true;
     if (!lhs.defined() || !rhs.defined()) return false;
-    if (lhs->IsInstance<TypeNode>()) {
-      if (!rhs->IsInstance<TypeNode>()) return false;
+    if (lhs.same_as(rhs)) return true;
+    if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
+      if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return 
false;
       return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
     }
-    if (lhs->IsInstance<ExprNode>()) {
-      if (!rhs->IsInstance<ExprNode>()) return false;
+    if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
+      if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return 
false;
       return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
     }
     if (const auto lhsm = lhs.as<IRModuleNode>()) {
diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc
new file mode 100644
index 0000000..0207fca
--- /dev/null
+++ b/tests/cpp/relay_pass_alpha_equal.cc
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <tvm/te/operation.h>
+#include <tvm/relay/expr.h>
+#include <tvm/relay/type.h>
+#include <tvm/relay/analysis.h>
+#include <tvm/relay/transform.h>
+#include <tvm/runtime/registry.h>
+
+using namespace tvm;
+
+/*!
+ * \brief Test helper that invokes a globally registered packed function with
+ *        two ObjectRef arguments and interprets its result as a bool.
+ *
+ *  The function is looked up with runtime::Registry::Get, which hands back a
+ *  pointer to the entry kept by the global registry. This helper therefore
+ *  holds a non-owning pointer and has nothing to free.
+ */
+class TestAlphaEquals {
+  /*! \brief Non-owning pointer to the bound function; non-null after construction. */
+  const runtime::PackedFunc* _packed_func = nullptr;
+
+ public:
+  /*!
+   * \brief Bind to the global function registered under \p func_name.
+   * \param func_name Name of a global function registered with TVM.
+   */
+  explicit TestAlphaEquals(const char* func_name) {
+    UpdatePackedFunc(func_name);
+  }
+
+  /*!
+   * \brief Re-bind this helper to another registered global function.
+   * \param func_name Name of a global function registered with TVM.
+   * \note Fails fast with a CHECK error when the name is not registered,
+   *       rather than dereferencing a null function later in operator().
+   */
+  void UpdatePackedFunc(const char* func_name) {
+    const runtime::PackedFunc* func = runtime::Registry::Get(func_name);
+    CHECK(func != nullptr) << "Cannot find global function " << func_name;
+    _packed_func = func;
+  }
+
+  /*!
+   * \brief Call the bound function as func(input_1, input_2).
+   * \param input_1 First argument forwarded to the packed function.
+   * \param input_2 Second argument forwarded to the packed function.
+   * \return The packed function's return value converted to bool.
+   */
+  bool operator()(ObjectRef input_1, ObjectRef input_2) const {
+    return static_cast<bool>((*_packed_func)(input_1, input_2));
+  }
+};
+
+/*!
+ * \brief A named TypeVar and a default-constructed (empty) TypeVar must be
+ *        reported unequal in both argument orders. This is checked through
+ *        the C++ AlphaEqual API and through the "relay._make._alpha_equal"
+ *        packed function.
+ */
+TEST(Relay, AlphaTestEmptyTypeNodes) {
+  auto x = TypeVar("x", kTypeData);
+  auto y = TypeVar();
+  EXPECT_FALSE(relay::AlphaEqual(x, y));
+  // The comparison is meant to be symmetric, so check the reverse order too.
+  EXPECT_FALSE(relay::AlphaEqual(y, x));
+
+  TestAlphaEquals test_equals("relay._make._alpha_equal");
+  EXPECT_FALSE(test_equals(x, y));
+  EXPECT_FALSE(test_equals(y, x));
+}
+
+int main(int argc, char** argv) {
+  // Parse gtest's command-line flags first. The death-test style is set
+  // afterwards, so "threadsafe" overrides any value given on the command line.
+  ::testing::InitGoogleTest(&argc, argv);
+  ::testing::FLAGS_gtest_death_test_style = "threadsafe";
+  const int status = RUN_ALL_TESTS();
+  return status;
+}
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index 7e34f48..ec026be 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -28,6 +28,15 @@ def alpha_equal(x, y):
     """
     return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == 
analysis.structural_hash(y)
 
+def alpha_equal_commutative(x, y):
+    """
+    Check for commutative property of equality
+    """
+    xy = analysis.alpha_equal(x, y)
+    yx = analysis.alpha_equal(y, x)
+    assert xy == yx
+    return xy
+
 def test_tensor_type_alpha_equal():
     t1 = relay.TensorType((3, 4), "float32")
     t2 = relay.TensorType((3, 4), "float32")
@@ -219,6 +228,26 @@ def test_constant_alpha_equal():
     assert not alpha_equal(x, y)
     assert alpha_equal(x, relay.const(1))
 
+def test_type_node_alpha_equal():
+    v1 = relay.TypeVar('v1', 6)
+    v2 = relay.TypeVar('v2', 6)
+    assert not alpha_equal(v1, v2)
+
+    v1 = relay.TypeVar('v1', 0)
+    v2 = relay.TypeVar('v2', 6)
+    assert not alpha_equal(v1, v2)
+
+    assert alpha_equal_commutative(v1, v1)
+
+def test_type_node_incompatible_alpha_equal():
+    v1 = relay.TypeVar('v1', 6)
+    v2 = relay.Var("v2")
+    assert not alpha_equal_commutative(v1, v2)
+
+def test_expr_node_incompatible_alpha_equal():
+    v1 = relay.Var("v1")
+    v2 = relay.PatternVar(relay.Var("v2"))
+    assert not alpha_equal_commutative(v1, v2)
 
 def test_var_alpha_equal():
     v1 = relay.Var("v1")
@@ -676,6 +705,9 @@ if __name__ == "__main__":
     test_tensor_type_alpha_equal()
     test_incomplete_type_alpha_equal()
     test_constant_alpha_equal()
+    test_type_node_alpha_equal()
+    test_type_node_incompatible_alpha_equal()
+    test_expr_node_incompatible_alpha_equal()
     test_func_type_alpha_equal()
     test_tuple_type_alpha_equal()
     test_type_relation_alpha_equal()

Reply via email to