This is an automated email from the ASF dual-hosted git repository. lunderberg pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new b91d4e55b3 [TVMScript] Produce empty DictAttrs when R.func_attrs is absent (#16844) b91d4e55b3 is described below commit b91d4e55b3f66a10508b4b492378173be75ba1a5 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Fri Apr 5 07:21:59 2024 -0500 [TVMScript] Produce empty DictAttrs when R.func_attrs is absent (#16844) A follow-up to https://github.com/apache/tvm/pull/16745. For Relax functions produced in TVMScript, when `R.func_attrs` was not present, the default was set to `None` instead of an empty dictionary. --- src/relax/ir/expr.cc | 4 ++++ src/script/ir_builder/relax/frame.cc | 3 +-- src/tir/ir/function.cc | 4 ++++ tests/python/relax/test_tvmscript_parser.py | 22 ++++++++++++++++++++++ 4 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index b709039e8c..1b5551e509 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -493,6 +493,10 @@ TVM_REGISTER_NODE_TYPE(FunctionNode); Function::Function(Array<Var> params, Expr body, Optional<StructInfo> ret_struct_info, bool is_pure, DictAttrs attrs, Span span) { + if (!attrs.defined()) { + attrs = DictAttrs(); + } + // Set the function type. // For function, we take a conservative approach and require the function type // to be known at construction time. diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc index b95db57a88..792331dda4 100644 --- a/src/script/ir_builder/relax/frame.cc +++ b/src/script/ir_builder/relax/frame.cc @@ -61,13 +61,12 @@ void FunctionFrameNode::ExitWithScope() { !attrs.count(tvm::attr::kGlobalSymbol)) { attrs.Set(tvm::attr::kGlobalSymbol, name.value()); } - auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs); this->block_builder->EndScope(); tvm::relax::Function func(/*params=*/params, /*body=*/body, /*ret_struct_info=*/ret_struct_info, /*is_pure=*/is_pure.value_or(Bool(true))->value, - /*attrs=*/dict_attrs); + /*attrs=*/DictAttrs(attrs)); // Step 2: Update IRModule. if (builder->frames.empty()) { // Case 0. No outer frame, return function directly diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc index 8a3d2d6947..14dd0eadb6 100644 --- a/src/tir/ir/function.cc +++ b/src/tir/ir/function.cc @@ -70,6 +70,10 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) { // Get the function type of a PrimFunc PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type, Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span span) { + if (!attrs.defined()) { + attrs = DictAttrs(); + } + // Assume void-return type for now // TODO(tvm-team) consider type deduction from body. if (!ret_type.defined()) { diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index c8db26c81b..e692768a12 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -2271,5 +2271,27 @@ def test_define_relax_function_using_global_var(): tvm.ir.assert_structural_equal(DefinedAllAtOnce, MainDefinedLater) +def test_function_attributes_are_defined(): + """func.attrs defaults to an empty DictAttrs""" + + @I.ir_module + class Module: + @R.function + def main(x: R.Tensor, shape: R.Shape(["m", "n"])): + output = Module.subroutine(x, shape) + return output + + @R.function + def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]): + q = x + m, n = T.int64(), T.int64() + z = R.match_cast(q, R.Tensor((m, n))) + w = z + return w + + for gvar, func in Module.functions.items(): + assert func.attrs is not None + + if __name__ == "__main__": tvm.testing.main()