This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6d97b95eed [Fix] Fix the purity flag of "vm.call_tir_dyn" and "kill" ops (#16773)
6d97b95eed is described below

commit 6d97b95eed4f1c76ef945bb6a5f38639f0f97a6c
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Sun Mar 24 14:07:23 2024 -0400

    [Fix] Fix the purity flag of "vm.call_tir_dyn" and "kill" ops (#16773)
    
    This PR fixes the purity flags of `relax.vm.call_tir_dyn` and a few
    "kill" ops. Their purity flags were set to True, which made it
    possible for `remove_all_unused` to remove them.
    
    * `relax.vm.call_tir_dyn` works by mutating the input args in place,
    which is not pure.
    * although the "kill" ops currently perform no action, their semantics
    suggest that they are impure.
    
    A regression test is added to prevent the unexpected removal from
    happening again.
---
 src/relax/op/op.cc                          | 15 ++++++-----
 tests/python/relax/test_analysis.py         | 42 +++++++++++++++++++++++------
 tests/python/relax/test_transform_cse.py    |  2 +-
 tests/python/relax/test_tvmscript_parser.py |  2 +-
 4 files changed, 44 insertions(+), 17 deletions(-)

diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index efbf648b48..7eb499f102 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -921,8 +921,8 @@ RELAY_REGISTER_OP("relax.memory.kill_storage")
     .set_num_inputs(1)
     .add_argument("storage", "Expr", "The storage to be killed.")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
-    // deallocation also isn't considered a "visible effect" as far as purity 
is concerned
-    .set_attr<Bool>("FPurity", Bool(true));
+    // We mark this as impure so it wouldn't be removed by "remove_all_unused"
+    .set_attr<Bool>("FPurity", Bool(false));
 
 Expr MakeMemKillStorage(Expr storage) {
   static const Op& op = Op::Get("relax.memory.kill_storage");
@@ -937,8 +937,8 @@ RELAY_REGISTER_OP("relax.memory.kill_tensor")
     .set_num_inputs(1)
     .add_argument("tensor", "Expr", "The tensor to be killed.")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
-    // memory deallocation also isn't considered a "visible effect" as far as 
purity is concerned
-    .set_attr<Bool>("FPurity", Bool(true));
+    // We mark this as impure so it wouldn't be removed by "remove_all_unused"
+    .set_attr<Bool>("FPurity", Bool(false));
 
 Expr MakeMemKillTensor(Expr tensor) {
   static const Op& op = Op::Get("relax.memory.kill_tensor");
@@ -1013,8 +1013,8 @@ TVM_REGISTER_OP("relax.vm.kill_object")
     .set_num_inputs(1)
     .add_argument("obj", "Expr", "The object to be killed.")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
-    // deallocation also isn't considered a "visible effect" as far as purity 
is concerned
-    .set_attr<Bool>("FPurity", Bool(true));
+    // We mark this as impure so it wouldn't be removed by "remove_all_unused"
+    .set_attr<Bool>("FPurity", Bool(false));
 
 Expr MakeVMKillObject(Expr obj) {
   static const Op& op = Op::Get("relax.vm.kill_object");
@@ -1031,7 +1031,8 @@ RELAY_REGISTER_OP("relax.vm.call_tir_dyn")
     .add_argument("args", "Tuple",
                   "The input arguments (list of tensors and last argument is 
ShapeExpr)")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo)
-    .set_attr<Bool>("FPurity", Bool(true));
+    // "relax.vm.call_tir_dyn" works in an in-place way, which is impure.
+    .set_attr<Bool>("FPurity", Bool(false));
 
 Expr MakeCallTIRDyn(Expr func, Tuple args) {
   static const Op& op = Op::Get("relax.vm.call_tir_dyn");
diff --git a/tests/python/relax/test_analysis.py 
b/tests/python/relax/test_analysis.py
index 28ca13ad89..c790b1bc51 100644
--- a/tests/python/relax/test_analysis.py
+++ b/tests/python/relax/test_analysis.py
@@ -19,19 +19,21 @@ from typing import List, Set, Union
 
 import tvm
 import tvm.testing
-from tvm import tir
 from tvm import relax as rx
+from tvm import tir
 from tvm.relax.analysis import (
-    has_reshape_pattern,
-    udchain,
-    remove_all_unused,
-    name_to_binding,
-    all_vars,
     all_global_vars,
-    free_vars,
+    all_vars,
     bound_vars,
+    free_vars,
+    has_reshape_pattern,
+    name_to_binding,
+    remove_all_unused,
+    udchain,
 )
-from tvm.script import relax as R, tir as T
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
 
 
 def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]:
@@ -352,6 +354,30 @@ def test_retain_impure_calls_unused_in_binding_block():
     tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True)
 
 
+def test_retain_calls_to_impure_builtin_ops():
+    @I.ir_module
+    class Module:
+        @T.prim_func(private=True)
+        def my_tir(A: T.handle, B: T.handle, n: T.int64):
+            T.evaluate(0)
+
+        @R.function(pure=False)
+        def main(x: R.Tensor(("n",), "float32")):
+            cls = Module
+            n = T.int64()
+            storage = R.memory.alloc_storage((n * 4,), 0, "global", "float32")
+            alloc = R.memory.alloc_tensor(storage, R.prim_value(0), 
R.shape([n]), "float32")
+            # "call_tir_dyn" is impure which shouldn't be removed.
+            R.vm.call_tir_dyn(cls.my_tir, (x, alloc, R.shape([n])))
+            # "kill_tensor"/"kill_storage" are impure which shouldn't be 
removed.
+            R.memory.kill_tensor(alloc)
+            R.memory.kill_storage(storage)
+            return x
+
+    after = remove_all_unused(Module["main"])
+    tvm.ir.assert_structural_equal(after, Module["main"], map_free_vars=True)
+
+
 def test_name_to_binding_var_shadowing():
     @R.function
     def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
diff --git a/tests/python/relax/test_transform_cse.py 
b/tests/python/relax/test_transform_cse.py
index b491577314..0998fb67c0 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -435,7 +435,7 @@ def test_call_tir_tuple_arg():
 def test_do_not_eliminate_dtype():
     @I.ir_module
     class Before:
-        @R.function
+        @R.function(pure=False)
         def foo() -> R.Tensor((32, 64), "int32"):
             obj: R.Object = R.vm.alloc_storage(
                 R.shape([24576]), runtime_device_index=0, dtype="uint8"
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index 3f806de28d..109971ce37 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -1552,7 +1552,7 @@ def test_memory_ops():
 
 
 def test_vm_ops():
-    @R.function
+    @R.function(pure=False)
     def foo(x: R.Tensor(("m", "n"), dtype="float32")):
         m = T.int64()
         n = T.int64()

Reply via email to