This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 460f6f1d3e [QoL][Relax] Infer StructInfo for relax::Tuple on construction (#16860)
460f6f1d3e is described below

commit 460f6f1d3e1625882df701252234350f83aa6da1
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Tue Apr 16 16:28:00 2024 -0500

    [QoL][Relax] Infer StructInfo for relax::Tuple on construction (#16860)
    
    Prior to this commit, the `relax::Tuple` constructor left the
    `struct_info_` field undefined.  This is inconsistent with other Relax
    leaf nodes, such as `relax::PrimValue`, `relax::Constant`, and
    `relax::ExternFunc`, which initialize their struct info on
    construction.
    
    This commit updates the `relax::Tuple` constructor to define
    `struct_info_` as `TupleStructInfo`, if all fields have a known struct
    info.  If any field does not have a known struct info, the existing
    behavior is kept: `struct_info_` is left as `NullOpt` and is later
    populated by the `relax::BlockBuilder`.
---
 src/relax/ir/expr.cc            | 16 ++++++++++++++++
 tests/python/relax/test_expr.py | 19 +++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index 0530bb770b..dd0f68dca4 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -137,9 +137,25 @@ TVM_REGISTER_GLOBAL("relax.If")
     });
 
 Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) {
+  Optional<StructInfo> tuple_sinfo = [&]() -> Optional<StructInfo> {
+    Array<StructInfo> field_sinfo;  // struct info of each field, in field order
+    for (const auto& field : fields) {
+      if (field->struct_info_.defined()) {
+        field_sinfo.push_back(GetStructInfo(field));
+      } else {
+        return NullOpt;  // unknown field sinfo: defer to relax::BlockBuilder
+      }
+    }
+    return TupleStructInfo(field_sinfo);  // every field's struct info is known
+  }();  // invoked immediately, before `fields` is moved from below
+
   ObjectPtr<TupleNode> n = make_object<TupleNode>();
   n->fields = std::move(fields);
   n->span = std::move(span);
+  if (tuple_sinfo) {  // derive checked_type_ from the inferred struct info
+    n->checked_type_ = GetStaticType(tuple_sinfo.value());
+  }
+  n->struct_info_ = tuple_sinfo;  // NullOpt when any field lacked struct info
   data_ = std::move(n);
 }
 
diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py
index af1bc851be..b20c9ef2d9 100644
--- a/tests/python/relax/test_expr.py
+++ b/tests/python/relax/test_expr.py
@@ -86,6 +86,25 @@ def test_tuple() -> None:
         t[-3]
 
 
+def test_tuple_sinfo_inferred_on_construction():
+    v0 = rx.Var("v0", rx.ObjectStructInfo())
+    v1 = rx.Var("v1", rx.ObjectStructInfo())
+    tup = rx.Tuple((v0, v1))  # both fields carry explicit struct info
+
+    assert tup.struct_info_ is not None  # set by the constructor itself
+    tvm.ir.assert_structural_equal(
+        tup.struct_info, rx.TupleStructInfo([rx.ObjectStructInfo(), rx.ObjectStructInfo()])
+    )
+
+
+def test_tuple_sinfo_requires_fields_with_known_sinfo():
+    v0 = rx.Var("v0", rx.ObjectStructInfo())
+    v1 = rx.Var("v1")  # no struct info annotation
+    tup = rx.Tuple((v0, v1))
+
+    assert tup.struct_info_ is None  # left unset; relax::BlockBuilder populates it later
+
+
 def test_match_cast() -> None:
     # match_cast([16, 8], [m, n])
     m = tir.Var("m", dtype="int64")

Reply via email to