This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 460f6f1d3e [QoL][Relax] Infer StructInfo for relax::Tuple on construction (#16860) 460f6f1d3e is described below commit 460f6f1d3e1625882df701252234350f83aa6da1 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Tue Apr 16 16:28:00 2024 -0500 [QoL][Relax] Infer StructInfo for relax::Tuple on construction (#16860) Prior to this commit, the `relax::Tuple` constructor left the `struct_info_` field undefined. This is inconsistent with other Relax leaf nodes, such as `relax::PrimValue`, `relax::Constant`, and `relax::ExternFunc`, which initialize their struct info on construction. This commit updates the `relax::Tuple` constructor to define `struct_info_` as `TupleStructInfo` (and `checked_type_` as the corresponding static type) when all fields have a known struct info. If any field does not have a known struct info, the existing behavior is kept: `struct_info_` is left as `NullOpt` and is later populated by the `relax::BlockBuilder`. --- src/relax/ir/expr.cc | 16 ++++++++++++++++ tests/python/relax/test_expr.py | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 0530bb770b..dd0f68dca4 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -137,9 +137,25 @@ TVM_REGISTER_GLOBAL("relax.If") }); Tuple::Tuple(tvm::Array<relay::Expr> fields, Span span) { + Optional<StructInfo> tuple_sinfo = [&]() -> Optional<StructInfo> { + Array<StructInfo> field_sinfo; + for (const auto& field : fields) { + if (field->struct_info_.defined()) { + field_sinfo.push_back(GetStructInfo(field)); + } else { + return NullOpt; + } + } + return TupleStructInfo(field_sinfo); + }(); + ObjectPtr<TupleNode> n = make_object<TupleNode>(); n->fields = std::move(fields); n->span = std::move(span); + if (tuple_sinfo) { + n->checked_type_ = GetStaticType(tuple_sinfo.value()); + } + n->struct_info_ = tuple_sinfo; data_ = std::move(n); } diff --git a/tests/python/relax/test_expr.py
b/tests/python/relax/test_expr.py index af1bc851be..b20c9ef2d9 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -86,6 +86,25 @@ def test_tuple() -> None: t[-3] +def test_tuple_sinfo_inferred_on_construction(): + v0 = rx.Var("v0", rx.ObjectStructInfo()) + v1 = rx.Var("v1", rx.ObjectStructInfo()) + tup = rx.Tuple((v0, v1)) + + assert tup.struct_info_ is not None + tvm.ir.assert_structural_equal( + tup.struct_info, rx.TupleStructInfo([rx.ObjectStructInfo(), rx.ObjectStructInfo()]) + ) + + +def test_tuple_sinfo_requires_fields_with_known_sinfo(): + v0 = rx.Var("v0", rx.ObjectStructInfo()) + v1 = rx.Var("v1") + tup = rx.Tuple((v0, v1)) + + assert tup.struct_info_ is None + + def test_match_cast() -> None: # match_cast([16, 8], [m, n]) m = tir.Var("m", dtype="int64")