jroesch commented on a change in pull request #7029:
URL: https://github.com/apache/tvm/pull/7029#discussion_r535741593



##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -25,59 +25,106 @@
 import pytest
 
 
-class env:
-    def __init__(self):
-        self.shape = tvm.runtime.convert([1, 2, 3])
-        self.tt = relay.TensorType(self.shape, "float32")
-        self.int32 = relay.TensorType([], "int32")
-        self.float32 = relay.TensorType([], "float32")
-        self.one = relay.const(1.0)
-        self.two = relay.const(2.0)
-        self.three = relay.const(3.0)
-        self.a = relay.Var("a", self.float32)
-        self.b = relay.Var("b", self.float32)
-        self.c = relay.Var("c", self.float32)
-        self.d = relay.Var("d", self.float32)
-        self.e = relay.Var("e", self.float32)
-        self.x = relay.Var("x", self.int32)
-        self.y = relay.Var("y", self.int32)
-        self.z = relay.Var("z", self.int32)
-
-
-e = env()
-
-
-def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, tvm.transform.Pass)
-    mod = tvm.IRModule.from_expr(expr)
-    mod = opt_pass(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def test_let():
-    orig = relay.Let(e.x, e.y, e.z)
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
-
-
-def test_used_let():
-    orig = relay.Let(e.c, e.one, e.c + e.c)
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    expected = relay.Let(e.c, e.one, e.c + e.c)
-    assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
+# class env:
+#     def __init__(self):
+#         self.shape = tvm.runtime.convert([1, 2, 3])
+#         self.tt = relay.TensorType(self.shape, "float32")
+#         self.int32 = relay.TensorType([], "int32")
+#         self.float32 = relay.TensorType([], "float32")
+#         self.one = relay.const(1.0)
+#         self.two = relay.const(2.0)
+#         self.three = relay.const(3.0)
+#         self.a = relay.Var("a", self.float32)
+#         self.b = relay.Var("b", self.float32)
+#         self.c = relay.Var("c", self.float32)
+#         self.d = relay.Var("d", self.float32)
+#         self.e = relay.Var("e", self.float32)
+#         self.x = relay.Var("x", self.int32)
+#         self.y = relay.Var("y", self.int32)
+#         self.z = relay.Var("z", self.int32)
+
+
+# e = env()
+
+
+# def run_opt_pass(expr, opt_pass):
+#     assert isinstance(opt_pass, tvm.transform.Pass)
+#     mod = tvm.IRModule.from_expr(expr)
+#     mod = opt_pass(mod)
+#     entry = mod["main"]
+#     return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def optimize_source(source, passes):
+    if not isinstance(passes, list):
+        passes = [passes]
+
+    optimize = tvm.transform.Sequential(passes)
+    module = tvm.parser.parse(source)
+    return optimize(module)
+
+
+def optimize_and_check(before_source, after_source, passes):
+    optimize_module = optimize_source(before_source, passes)
+    after_module = tvm.parser.parse(after_source)
+    print(optimize_module)
+    print(after_module)
+    assert tvm.ir.structural_equal(after_module, optimize_module)
+
+
+def test_dead_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        %z
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        %z
+    }
+    """
+    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
 
 
-def test_inline():
-    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
-    orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
-    tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
+def test_one_live_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        let %y = 2;
+        %x + %x
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        %x + %x
+    }
+    """
+    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
 
 
-def test_chain_unused_let():
-    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
+def test_nested_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%d: int, %b: int) {
+        let %a = %b;
+        let %c = %d;
+        %c
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%d: int, %b: int) {
+        let %c = %d;

Review comment:
       It looks like the current liveness analysis keeps this binding around; it is only removed when the IR is constructed slightly differently.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to