This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 6d97b95eed [Fix] Fix the purity flag of "vm.call_tir_dyn" and "kill" ops (#16773) 6d97b95eed is described below commit 6d97b95eed4f1c76ef945bb6a5f38639f0f97a6c Author: Ruihang Lai <ruiha...@cs.cmu.edu> AuthorDate: Sun Mar 24 14:07:23 2024 -0400 [Fix] Fix the purity flag of "vm.call_tir_dyn" and "kill" ops (#16773) This PR fixes the purity flag of `relax.vm.call_tir_dyn` and a few other "kill" ops. Their purity flags were set to True, which made it possible for `remove_all_unused` to remove them. * `relax.vm.call_tir_dyn` works by mutating its input args in place, which is not pure. * Though the "kill" ops currently perform no actions, their semantics suggest that they are impure. A regression test is added to prevent this unexpected removal from happening again. --- src/relax/op/op.cc | 15 ++++++----- tests/python/relax/test_analysis.py | 42 +++++++++++++++++++++++------ tests/python/relax/test_transform_cse.py | 2 +- tests/python/relax/test_tvmscript_parser.py | 2 +- 4 files changed, 44 insertions(+), 17 deletions(-) diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc index efbf648b48..7eb499f102 100644 --- a/src/relax/op/op.cc +++ b/src/relax/op/op.cc @@ -921,8 +921,8 @@ RELAY_REGISTER_OP("relax.memory.kill_storage") .set_num_inputs(1) .add_argument("storage", "Expr", "The storage to be killed.") .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo) - // deallocation also isn't considered a "visible effect" as far as purity is concerned - .set_attr<Bool>("FPurity", Bool(true)); + // We mark this as impure so it wouldn't be removed by "remove_all_unused" + .set_attr<Bool>("FPurity", Bool(false)); Expr MakeMemKillStorage(Expr storage) { static const Op& op = Op::Get("relax.memory.kill_storage"); @@ -937,8 +937,8 @@ RELAY_REGISTER_OP("relax.memory.kill_tensor") .set_num_inputs(1) .add_argument("tensor", "Expr", "The tensor to be killed.") .set_attr<FInferStructInfo>("FInferStructInfo",
ReturnVoidStructInfo) - // memory deallocation also isn't considered a "visible effect" as far as purity is concerned - .set_attr<Bool>("FPurity", Bool(true)); + // We mark this as impure so it wouldn't be removed by "remove_all_unused" + .set_attr<Bool>("FPurity", Bool(false)); Expr MakeMemKillTensor(Expr tensor) { static const Op& op = Op::Get("relax.memory.kill_tensor"); @@ -1013,8 +1013,8 @@ TVM_REGISTER_OP("relax.vm.kill_object") .set_num_inputs(1) .add_argument("obj", "Expr", "The object to be killed.") .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo) - // deallocation also isn't considered a "visible effect" as far as purity is concerned - .set_attr<Bool>("FPurity", Bool(true)); + // We mark this as impure so it wouldn't be removed by "remove_all_unused" + .set_attr<Bool>("FPurity", Bool(false)); Expr MakeVMKillObject(Expr obj) { static const Op& op = Op::Get("relax.vm.kill_object"); @@ -1031,7 +1031,8 @@ RELAY_REGISTER_OP("relax.vm.call_tir_dyn") .add_argument("args", "Tuple", "The input arguments (list of tensors and last argument is ShapeExpr)") .set_attr<FInferStructInfo>("FInferStructInfo", ReturnVoidStructInfo) - .set_attr<Bool>("FPurity", Bool(true)); + // "relax.vm.call_tir_dyn" works in an in-place way, which is impure. 
+ .set_attr<Bool>("FPurity", Bool(false)); Expr MakeCallTIRDyn(Expr func, Tuple args) { static const Op& op = Op::Get("relax.vm.call_tir_dyn"); diff --git a/tests/python/relax/test_analysis.py b/tests/python/relax/test_analysis.py index 28ca13ad89..c790b1bc51 100644 --- a/tests/python/relax/test_analysis.py +++ b/tests/python/relax/test_analysis.py @@ -19,19 +19,21 @@ from typing import List, Set, Union import tvm import tvm.testing -from tvm import tir from tvm import relax as rx +from tvm import tir from tvm.relax.analysis import ( - has_reshape_pattern, - udchain, - remove_all_unused, - name_to_binding, - all_vars, all_global_vars, - free_vars, + all_vars, bound_vars, + free_vars, + has_reshape_pattern, + name_to_binding, + remove_all_unused, + udchain, ) -from tvm.script import relax as R, tir as T +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T def var_name_set(vars: List[Union[rx.Var, rx.GlobalVar]]) -> Set[str]: @@ -352,6 +354,30 @@ def test_retain_impure_calls_unused_in_binding_block(): tvm.ir.assert_structural_equal(expected.body, after, map_free_vars=True) +def test_retain_calls_to_impure_builtin_ops(): + @I.ir_module + class Module: + @T.prim_func(private=True) + def my_tir(A: T.handle, B: T.handle, n: T.int64): + T.evaluate(0) + + @R.function(pure=False) + def main(x: R.Tensor(("n",), "float32")): + cls = Module + n = T.int64() + storage = R.memory.alloc_storage((n * 4,), 0, "global", "float32") + alloc = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([n]), "float32") + # "call_tir_dyn" is impure which shouldn't be removed. + R.vm.call_tir_dyn(cls.my_tir, (x, alloc, R.shape([n]))) + # "kill_tensor"/"kill_storage" are impure which shouldn't be removed. 
+ R.memory.kill_tensor(alloc) + R.memory.kill_storage(storage) + return x + + after = remove_all_unused(Module["main"]) + tvm.ir.assert_structural_equal(after, Module["main"], map_free_vars=True) + + def test_name_to_binding_var_shadowing(): @R.function def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor: diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py index b491577314..0998fb67c0 100644 --- a/tests/python/relax/test_transform_cse.py +++ b/tests/python/relax/test_transform_cse.py @@ -435,7 +435,7 @@ def test_call_tir_tuple_arg(): def test_do_not_eliminate_dtype(): @I.ir_module class Before: - @R.function + @R.function(pure=False) def foo() -> R.Tensor((32, 64), "int32"): obj: R.Object = R.vm.alloc_storage( R.shape([24576]), runtime_device_index=0, dtype="uint8" diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 3f806de28d..109971ce37 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1552,7 +1552,7 @@ def test_memory_ops(): def test_vm_ops(): - @R.function + @R.function(pure=False) def foo(x: R.Tensor(("m", "n"), dtype="float32")): m = T.int64() n = T.int64()