This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new d129767d71 [Relax] Support constant folding for call_tir with tuple outputs (#18736)
d129767d71 is described below

commit d129767d71f54e4fceefa9f7f812bc6db8f2876c
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Feb 11 21:51:05 2026 +0800

    [Relax] Support constant folding for call_tir with tuple outputs (#18736)
    
    ## Why
    
    Constant folding skipped call_tir nodes with tuple (multi-tensor)
    outputs, leaving foldable operations unoptimized.
    
    ## How
    
    - Add ConstEvaluateCallTIRTuple to handle call_tir with TupleStructInfo
    output by allocating and packing multiple output tensors
    - Route VisitCallTIR through tuple vs single-tensor paths based on
    sinfo_args type
    - Add test for folding a split-like prim_func with two output tensors
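    
    For illustration only (not part of this commit), a minimal sketch of the newly
    foldable pattern; the module and function names below are made up, and unlike the
    regression test further down (which binds the input through the `gen_mod` helper),
    the constant input here is embedded directly with `R.const`:
    
    ```python
    import numpy as np
    import tvm
    from tvm import relax
    from tvm.script import ir_module, relax as R, tir as T
    
    c0_np = np.arange(16).astype("float32").reshape(4, 4)
    
    @ir_module
    class Mod:
        @T.prim_func
        def split(
            A: T.Buffer((4, 4), "float32"),
            B: T.Buffer((2, 4), "float32"),
            C: T.Buffer((2, 4), "float32"),
        ):
            for i, j in T.grid(2, 4):
                with T.block("upper"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj]
            for i, j in T.grid(2, 4):
                with T.block("lower"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    C[vi, vj] = A[vi + 2, vj]
    
        @R.function
        def main():
            cls = Mod
            # call_tir with a list-valued out_sinfo, i.e. a tuple (multi-tensor) output
            lv = R.call_tir(
                cls.split,
                (R.const(c0_np, "float32"),),
                out_sinfo=[R.Tensor((2, 4), "float32"), R.Tensor((2, 4), "float32")],
            )
            return lv
    
    # With this patch, the call_tir above is expected to be evaluated at compile
    # time and replaced by a tuple of two relax constants (the two halves of c0_np).
    folded = relax.transform.FoldConstant()(Mod)
    ```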
    
    Signed-off-by: Guan-Ming Chiu <[email protected]>
---
 src/relax/transform/fold_constant.cc               | 66 ++++++++++++++++++----
 tests/python/relax/test_transform_fold_constant.py | 51 +++++++++++++++++
 2 files changed, 107 insertions(+), 10 deletions(-)

diff --git a/src/relax/transform/fold_constant.cc b/src/relax/transform/fold_constant.cc
index b714d49243..5a26d15850 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -47,11 +47,10 @@ class ConstantFolder : public ExprMutator {
    * constant shape and get runtime shape tuple from it.
    * \param struct_info The given struct info whose shape inside is to be casted.
    * \return The runtime shape tuple, or nullopt if it is not a constant shape.
-   * \note Only TensorStructInfo is supported at this moment. Return std::nullopt
+   * \note Only TensorStructInfo is supported. Returns std::nullopt
    * if the input struct info is not TensorStructInfo.
    */
   static ffi::Optional<ffi::Shape> MatchConstShape(const StructInfo& struct_info) {
-    // Only support single output for call_tir at this moment.
     const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
     if (tensor_sinfo == nullptr) {
       return std::nullopt;
@@ -143,8 +142,8 @@ class ConstantFolder : public ExprMutator {
     return true;
   }
 
-  // Try constant evaluate the function call
-  // if failed return std::nullopt
+  // Try constant evaluate a call_tir with a single tensor output.
+  // Returns std::nullopt on failure.
   ffi::Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func,
                                            ffi::Array<runtime::Tensor> arr_args, ffi::Shape shape,
                                            DataType ret_type) {
@@ -175,25 +174,72 @@ class ConstantFolder : public ExprMutator {
     return Constant(ret_tensor);
   }
 
+  // Try constant evaluate a call_tir with tuple outputs (multiple output tensors).
+  // Returns std::nullopt on failure.
+  ffi::Optional<Expr> ConstEvaluateCallTIRTuple(tir::PrimFunc tir_func,
+                                                ffi::Array<runtime::Tensor> arr_args,
+                                                const TupleStructInfoNode* tuple_sinfo) {
+    ffi::Optional<ffi::Function> func = GetCachedBuild(tir_func);
+    if (!func) return std::nullopt;
+
+    DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
+    size_t num_outputs = tuple_sinfo->fields.size();
+
+    // Match shapes and dtypes for all output fields.
+    std::vector<runtime::Tensor> ret_tensors;
+    for (size_t i = 0; i < num_outputs; ++i) {
+      ffi::Optional<ffi::Shape> shape = MatchConstShape(tuple_sinfo->fields[i]);
+      if (!shape) return std::nullopt;
+      auto tensor_sinfo = Downcast<TensorStructInfo>(tuple_sinfo->fields[i]);
+      if (tensor_sinfo->IsUnknownDtype()) return std::nullopt;
+      ret_tensors.push_back(runtime::Tensor::Empty(shape.value(), tensor_sinfo->dtype, cpu_dev));
+    }
+
+    // Pack input args + all output tensors.
+    std::vector<runtime::Tensor> temp_args(arr_args.begin(), arr_args.end());
+    std::vector<AnyView> packed_args;
+    packed_args.reserve(temp_args.size() + num_outputs);
+    for (const auto& arg : temp_args) {
+      packed_args.push_back(arg);
+    }
+    for (const auto& out_tensor : ret_tensors) {
+      packed_args.push_back(out_tensor);
+    }
+
+    ffi::Any ret;
+    func.value().CallPacked(ffi::PackedArgs(packed_args.data(), packed_args.size()), &ret);
+
+    ffi::Array<Expr> fields;
+    for (size_t i = 0; i < num_outputs; ++i) {
+      fields.push_back(Constant(ret_tensors[i]));
+    }
+    return Tuple(fields);
+  }
+
   // Returns the folded expr if the call is successfully folded to constant, otherwise null.
   ffi::Optional<Expr> VisitCallTIR(Call call) {
-    // call_tir needs to have at least three arguments
+    // call_tir needs to have at least two arguments
     ICHECK_GE(call->args.size(), 2);
     ffi::Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
     ICHECK(call->args[1].as<TupleNode>()) << "call_tir.args[1] must be Tuple";
     ffi::Optional<ffi::Array<runtime::Tensor>> arr_args =
         MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields);
     ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one sinfo arg";
+
+    if (!func || !arr_args) return {};
+
+    // Handle tuple output: sinfo_args[0] is a TupleStructInfo.
+    if (const auto* tuple_sinfo = call->sinfo_args[0].as<TupleStructInfoNode>()) {
+      return ConstEvaluateCallTIRTuple(func.value(), arr_args.value(), tuple_sinfo);
+    }
+
+    // Handle single tensor output.
     ffi::Optional<ffi::Shape> shape = MatchConstShape(call->sinfo_args[0]);
-    bool output_not_tuple = call->sinfo_args.size() == 1;
-    // Pattern 0: call constant function, const argument with const shape.
-    if (func && arr_args && shape && output_not_tuple) {
+    if (shape) {
       TensorStructInfo ret_sinfo = Downcast<TensorStructInfo>(call->struct_info_);
-      // value_or will return value if it is not null, otherwise return or
       return ConstEvaluateCallTIR(func.value(), arr_args.value(), shape.value(), ret_sinfo->dtype)
           .value_or({});
     }
-    // TODO(hongyi): support const-fold tuple outputs
     return {};
   }
 
diff --git a/tests/python/relax/test_transform_fold_constant.py b/tests/python/relax/test_transform_fold_constant.py
index e3eb7f6367..92125bc351 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -442,5 +442,56 @@ def test_fold_shape_computation():
     tvm.ir.assert_structural_equal(after, expected)
 
 
+def test_fold_tuple_output():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def split(
+            A: T.Buffer((4, 4), "float32"),
+            B: T.Buffer((2, 4), "float32"),
+            C: T.Buffer((2, 4), "float32"),
+        ) -> None:
+            for i, j in T.grid(2, 4):
+                with T.block("upper"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    B[vi, vj] = A[vi, vj]
+            for i, j in T.grid(2, 4):
+                with T.block("lower"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    C[vi, vj] = A[vi + 2, vj]
+
+        @R.function
+        def before(c0: R.Tensor((4, 4), "float32")):
+            cls = Module
+            lv0 = relax.call_tir(
+                cls.split,
+                (c0,),
+                out_sinfo=[
+                    R.Tensor((2, 4), dtype="float32"),
+                    R.Tensor((2, 4), dtype="float32"),
+                ],
+            )
+            return lv0
+
+        @R.function
+        def expected(
+            c1: R.Tensor((2, 4), "float32"), c2: R.Tensor((2, 4), "float32")
+        ) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")):
+            lv0: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4), dtype="float32")) = (
+                c1,
+                c2,
+            )
+            return lv0
+
+    c0_np = np.arange(16).astype("float32").reshape(4, 4)
+    c1_np = c0_np[:2]
+    c2_np = c0_np[2:]
+    before = gen_mod(Module, "before", {"c0": c0_np})
+    expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np})
+
+    after = relax.transform.FoldConstant()(before)
+    tvm.ir.assert_structural_equal(after, expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
