This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d129767d71 [Relax] Support constant folding for call_tir with tuple
outputs (#18736)
d129767d71 is described below
commit d129767d71f54e4fceefa9f7f812bc6db8f2876c
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Wed Feb 11 21:51:05 2026 +0800
[Relax] Support constant folding for call_tir with tuple outputs (#18736)
## Why
Constant folding skipped call_tir nodes with tuple (multi-tensor)
outputs, leaving foldable operations unoptimized.
## How
- Add ConstEvaluateCallTIRTuple to handle call_tir with TupleStructInfo
output by allocating and packing multiple output tensors
- Route VisitCallTIR through tuple vs single-tensor paths based on
sinfo_args type
- Add test for folding a split-like prim_func with two output tensors
Signed-off-by: Guan-Ming Chiu <[email protected]>
---
src/relax/transform/fold_constant.cc | 66 ++++++++++++++++++----
tests/python/relax/test_transform_fold_constant.py | 51 +++++++++++++++++
2 files changed, 107 insertions(+), 10 deletions(-)
diff --git a/src/relax/transform/fold_constant.cc
b/src/relax/transform/fold_constant.cc
index b714d49243..5a26d15850 100644
--- a/src/relax/transform/fold_constant.cc
+++ b/src/relax/transform/fold_constant.cc
@@ -47,11 +47,10 @@ class ConstantFolder : public ExprMutator {
* constant shape and get runtime shape tuple from it.
* \param struct_info The given struct info whose shape inside is to be
casted.
* \return The runtime shape tuple, or nullopt if it is not a constant shape.
- * \note Only TensorStructInfo is supported at this moment. Return
std::nullopt
+ * \note Only TensorStructInfo is supported. Returns std::nullopt
* if the input struct info is not TensorStructInfo.
*/
static ffi::Optional<ffi::Shape> MatchConstShape(const StructInfo&
struct_info) {
- // Only support single output for call_tir at this moment.
const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
if (tensor_sinfo == nullptr) {
return std::nullopt;
@@ -143,8 +142,8 @@ class ConstantFolder : public ExprMutator {
return true;
}
- // Try constant evaluate the function call
- // if failed return std::nullopt
+ // Try constant evaluate a call_tir with a single tensor output.
+ // Returns std::nullopt on failure.
ffi::Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func,
ffi::Array<runtime::Tensor>
arr_args, ffi::Shape shape,
DataType ret_type) {
@@ -175,25 +174,72 @@ class ConstantFolder : public ExprMutator {
return Constant(ret_tensor);
}
+ // Try constant evaluate a call_tir with tuple outputs (multiple output
tensors).
+ // Returns std::nullopt on failure.
+ ffi::Optional<Expr> ConstEvaluateCallTIRTuple(tir::PrimFunc tir_func,
+ ffi::Array<runtime::Tensor>
arr_args,
+ const TupleStructInfoNode*
tuple_sinfo) {
+ ffi::Optional<ffi::Function> func = GetCachedBuild(tir_func);
+ if (!func) return std::nullopt;
+
+ DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
+ size_t num_outputs = tuple_sinfo->fields.size();
+
+ // Match shapes and dtypes for all output fields.
+ std::vector<runtime::Tensor> ret_tensors;
+ for (size_t i = 0; i < num_outputs; ++i) {
+ ffi::Optional<ffi::Shape> shape =
MatchConstShape(tuple_sinfo->fields[i]);
+ if (!shape) return std::nullopt;
+ auto tensor_sinfo = Downcast<TensorStructInfo>(tuple_sinfo->fields[i]);
+ if (tensor_sinfo->IsUnknownDtype()) return std::nullopt;
+ ret_tensors.push_back(runtime::Tensor::Empty(shape.value(),
tensor_sinfo->dtype, cpu_dev));
+ }
+
+ // Pack input args + all output tensors.
+ std::vector<runtime::Tensor> temp_args(arr_args.begin(), arr_args.end());
+ std::vector<AnyView> packed_args;
+ packed_args.reserve(temp_args.size() + num_outputs);
+ for (const auto& arg : temp_args) {
+ packed_args.push_back(arg);
+ }
+ for (const auto& out_tensor : ret_tensors) {
+ packed_args.push_back(out_tensor);
+ }
+
+ ffi::Any ret;
+ func.value().CallPacked(ffi::PackedArgs(packed_args.data(),
packed_args.size()), &ret);
+
+ ffi::Array<Expr> fields;
+ for (size_t i = 0; i < num_outputs; ++i) {
+ fields.push_back(Constant(ret_tensors[i]));
+ }
+ return Tuple(fields);
+ }
+
// Returns the folded expr if the call is successfully folded to constant,
otherwise null.
ffi::Optional<Expr> VisitCallTIR(Call call) {
- // call_tir needs to have at least three arguments
+ // call_tir needs to have at least two arguments
ICHECK_GE(call->args.size(), 2);
ffi::Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
ICHECK(call->args[1].as<TupleNode>()) << "call_tir.args[1] must be Tuple";
ffi::Optional<ffi::Array<runtime::Tensor>> arr_args =
MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields);
ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one
sinfo arg";
+
+ if (!func || !arr_args) return {};
+
+ // Handle tuple output: sinfo_args[0] is a TupleStructInfo.
+ if (const auto* tuple_sinfo =
call->sinfo_args[0].as<TupleStructInfoNode>()) {
+ return ConstEvaluateCallTIRTuple(func.value(), arr_args.value(),
tuple_sinfo);
+ }
+
+ // Handle single tensor output.
ffi::Optional<ffi::Shape> shape = MatchConstShape(call->sinfo_args[0]);
- bool output_not_tuple = call->sinfo_args.size() == 1;
- // Pattern 0: call constant function, const argument with const shape.
- if (func && arr_args && shape && output_not_tuple) {
+ if (shape) {
TensorStructInfo ret_sinfo =
Downcast<TensorStructInfo>(call->struct_info_);
- // value_or will return value if it is not null, otherwise return or
return ConstEvaluateCallTIR(func.value(), arr_args.value(),
shape.value(), ret_sinfo->dtype)
.value_or({});
}
- // TODO(hongyi): support const-fold tuple outputs
return {};
}
diff --git a/tests/python/relax/test_transform_fold_constant.py
b/tests/python/relax/test_transform_fold_constant.py
index e3eb7f6367..92125bc351 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -442,5 +442,56 @@ def test_fold_shape_computation():
tvm.ir.assert_structural_equal(after, expected)
+def test_fold_tuple_output():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def split(
+ A: T.Buffer((4, 4), "float32"),
+ B: T.Buffer((2, 4), "float32"),
+ C: T.Buffer((2, 4), "float32"),
+ ) -> None:
+ for i, j in T.grid(2, 4):
+ with T.block("upper"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj]
+ for i, j in T.grid(2, 4):
+ with T.block("lower"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = A[vi + 2, vj]
+
+ @R.function
+ def before(c0: R.Tensor((4, 4), "float32")):
+ cls = Module
+ lv0 = relax.call_tir(
+ cls.split,
+ (c0,),
+ out_sinfo=[
+ R.Tensor((2, 4), dtype="float32"),
+ R.Tensor((2, 4), dtype="float32"),
+ ],
+ )
+ return lv0
+
+ @R.function
+ def expected(
+ c1: R.Tensor((2, 4), "float32"), c2: R.Tensor((2, 4), "float32")
+ ) -> R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4),
dtype="float32")):
+ lv0: R.Tuple(R.Tensor((2, 4), dtype="float32"), R.Tensor((2, 4),
dtype="float32")) = (
+ c1,
+ c2,
+ )
+ return lv0
+
+ c0_np = np.arange(16).astype("float32").reshape(4, 4)
+ c1_np = c0_np[:2]
+ c2_np = c0_np[2:]
+ before = gen_mod(Module, "before", {"c0": c0_np})
+ expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()