This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b91d4e55b3 [TVMScript] Produce empty DictAttrs when R.func_attrs is absent (#16844)
b91d4e55b3 is described below

commit b91d4e55b3f66a10508b4b492378173be75ba1a5
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Fri Apr 5 07:21:59 2024 -0500

    [TVMScript] Produce empty DictAttrs when R.func_attrs is absent (#16844)
    
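    A follow-up to https://github.com/apache/tvm/pull/16745.  For Relax
    functions produced in TVMScript, when `R.func_attrs` was not present,
    the function's `attrs` defaulted to `None` instead of an empty
    `DictAttrs`.  This commit produces an empty `DictAttrs` in that case,
    and also makes the `relax::Function` and `tir::PrimFunc` constructors
    replace an undefined `attrs` argument with an empty `DictAttrs`.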
---
 src/relax/ir/expr.cc                        |  4 ++++
 src/script/ir_builder/relax/frame.cc        |  3 +--
 src/tir/ir/function.cc                      |  4 ++++
 tests/python/relax/test_tvmscript_parser.py | 22 ++++++++++++++++++++++
 4 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index b709039e8c..1b5551e509 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -493,6 +493,10 @@ TVM_REGISTER_NODE_TYPE(FunctionNode);
 
 Function::Function(Array<Var> params, Expr body, Optional<StructInfo> 
ret_struct_info, bool is_pure,
                    DictAttrs attrs, Span span) {
+  if (!attrs.defined()) {
+    attrs = DictAttrs();
+  }
+
   // Set the function type.
   // For function, we take a conservative approach and require the function 
type
   // to be known at construction time.
diff --git a/src/script/ir_builder/relax/frame.cc b/src/script/ir_builder/relax/frame.cc
index b95db57a88..792331dda4 100644
--- a/src/script/ir_builder/relax/frame.cc
+++ b/src/script/ir_builder/relax/frame.cc
@@ -61,13 +61,12 @@ void FunctionFrameNode::ExitWithScope() {
       !attrs.count(tvm::attr::kGlobalSymbol)) {
     attrs.Set(tvm::attr::kGlobalSymbol, name.value());
   }
-  auto dict_attrs = attrs.empty() ? NullValue<DictAttrs>() : DictAttrs(attrs);
   this->block_builder->EndScope();
   tvm::relax::Function func(/*params=*/params,
                             /*body=*/body,
                             /*ret_struct_info=*/ret_struct_info,
                             /*is_pure=*/is_pure.value_or(Bool(true))->value,
-                            /*attrs=*/dict_attrs);
+                            /*attrs=*/DictAttrs(attrs));
   // Step 2: Update IRModule.
   if (builder->frames.empty()) {
     // Case 0. No outer frame, return function directly
diff --git a/src/tir/ir/function.cc b/src/tir/ir/function.cc
index 8a3d2d6947..14dd0eadb6 100644
--- a/src/tir/ir/function.cc
+++ b/src/tir/ir/function.cc
@@ -70,6 +70,10 @@ relax::StructInfo InferStructInfo(const PrimFunc& prim_func) {
 // Get the function type of a PrimFunc
 PrimFunc::PrimFunc(Array<tir::Var> params, Stmt body, Type ret_type,
                    Map<tir::Var, Buffer> buffer_map, DictAttrs attrs, Span 
span) {
+  if (!attrs.defined()) {
+    attrs = DictAttrs();
+  }
+
   // Assume void-return type for now
   // TODO(tvm-team) consider type deduction from body.
   if (!ret_type.defined()) {
diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py
index c8db26c81b..e692768a12 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -2271,5 +2271,27 @@ def test_define_relax_function_using_global_var():
     tvm.ir.assert_structural_equal(DefinedAllAtOnce, MainDefinedLater)
 
 
+def test_function_attributes_are_defined():
+    """func.attrs is a DictAttrs, even for a private function with no attributes"""
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor, shape: R.Shape(["m", "n"])):
+            output = Module.subroutine(x, shape)
+            return output
+
+        @R.function(private=True)
+        def subroutine(x: R.Tensor, _: R.Shape(["m", "n"])) -> R.Tensor(["m", "n"]):
+            q = x
+            m, n = T.int64(), T.int64()
+            z = R.match_cast(q, R.Tensor((m, n)))
+            w = z
+            return w
+
+    for _, func in Module.functions.items():
+        assert isinstance(func.attrs, tvm.ir.DictAttrs)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to