This is an automated email from the ASF dual-hosted git repository.
guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 71ac3e821b [Relax][NN] Add batch_flatten operator (#18677)
71ac3e821b is described below
commit 71ac3e821b750c29cbde0d676d0615777d13f95b
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Jan 29 13:07:25 2026 +0800
[Relax][NN] Add batch_flatten operator (#18677)
## Why
The `batch_flatten` operator was needed for CoreML backend support but
was not implemented in Relax, so the CoreML pattern and converter had to
be left commented out.
## How
- Add C++ operator `relax.nn.batch_flatten` with struct info inference
- Add Python wrapper in `relax.op.nn`
- Add legalization via `topi.reshape`
- Enable CoreML pattern and converter for `nn.batch_flatten`
- Add unit tests for operator, struct info inference, and legalization (see the usage sketch below)
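
Below is a minimal usage sketch (not part of the diff), assuming a TVM build that includes this commit; the shapes match the unit tests added here.

```python
from tvm import relax
from tvm.script import relax as R

# Build a small Relax function that calls the new operator.
bb = relax.BlockBuilder()
x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
with bb.function("main", [x]):
    # Struct info is inferred as R.Tensor((2, 60), "float32").
    gv = bb.emit(relax.op.nn.batch_flatten(x))
    bb.emit_func_output(gv)

# LegalizeOps lowers relax.nn.batch_flatten to a call_tir of topi.reshape.
mod = relax.transform.LegalizeOps()(bb.get())
mod.show()
```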
---
python/tvm/relax/backend/metal/coreml.py | 5 +-
python/tvm/relax/op/nn/__init__.py | 1 +
python/tvm/relax/op/nn/nn.py | 19 ++++++++
python/tvm/relax/transform/legalize_ops/nn.py | 7 +++
src/relax/op/nn/nn.cc | 50 ++++++++++++++++++++
src/relax/op/nn/nn.h | 3 ++
tests/python/relax/test_codegen_coreml.py | 1 -
tests/python/relax/test_op_nn.py | 54 ++++++++++++++++++++++
.../python/relax/test_transform_legalize_ops_nn.py | 43 +++++++++++++++++
9 files changed, 179 insertions(+), 4 deletions(-)
diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py
index dfc891dc1f..c7d922b599 100644
--- a/python/tvm/relax/backend/metal/coreml.py
+++ b/python/tvm/relax/backend/metal/coreml.py
@@ -142,11 +142,10 @@ register_patterns(
*default_unary_patterns(op_name="nn.relu"),
*default_unary_patterns(op_name="expand_dims"),
*default_unary_patterns(op_name="nn.avg_pool2d"),
+ *default_unary_patterns(op_name="nn.batch_flatten"),
*conv2d_patterns(),
*clip_patterns(),
*matmul_patterns(),
- # TODO(@tvm-team): enable when relax op is implemented
- # ("coreml.nn.batch_flatten", is_op("relax.nn.batch_flatten")(wildcard())),
]
)
@@ -271,7 +270,7 @@ _convert_map = {
"clip": _convert_clip,
"expand_dims": _convert_expand_dims,
"nn.relu": _convert_relu,
- # "nn.batch_flatten": _convert_batch_flatten,
+ "nn.batch_flatten": _convert_batch_flatten,
"nn.softmax": _convert_softmax,
"nn.conv2d": _convert_conv2d,
"nn.avg_pool2d": _convert_avg_pool2d,
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index ec1135ef2c..cf804440f1 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -25,6 +25,7 @@ from .nn import (
avg_pool1d,
avg_pool2d,
avg_pool3d,
+ batch_flatten,
batch_norm,
conv1d,
conv1d_transpose,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index c7be2a7ba6..e9710deca9 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -2249,3 +2249,22 @@ def attention_var_len(
causal_mask,
window_size,
) # type: ignore
+
+
+def batch_flatten(data: Expr) -> Expr:
+ """Flatten all dimensions except the first (batch) dimension.
+
+ This operation flattens a tensor of shape `(N, C, H, W, ...)` into
+ a 2D tensor of shape `(N, C*H*W*...)`.
+
+ Parameters
+ ----------
+ data : relax.Expr
+ The input data to the operator.
+
+ Returns
+ -------
+ result : relax.Expr
+ The flattened result with shape `(batch_size, flattened_features)`.
+ """
+ return _ffi_api.batch_flatten(data) # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index ed9802fc9e..1a0477af20 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -775,3 +775,10 @@ def _nn_nll_loss(bb: BlockBuilder, call: Call) -> Expr:
reduction=call.attrs.reduction,
ignore_index=call.attrs.ignore_index,
)
+
+
+@register_legalize("relax.nn.batch_flatten")
+def _nn_batch_flatten(bb: BlockBuilder, call: Call) -> Expr:
+ if call.struct_info.shape is None:
+ return call
+ return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values)
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index f4b9fe400b..0a23358343 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -1192,5 +1192,55 @@ TVM_REGISTER_OP("relax.nn.nll_loss")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNLLLoss)
.set_attr<Bool>("FPurity", Bool(true));
+/* relax.nn.batch_flatten */
+
+Expr batch_flatten(Expr data) {
+ static const Op& op = Op::Get("relax.nn.batch_flatten");
+ return Call(op, {std::move(data)}, {}, {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("relax.op.nn.batch_flatten", batch_flatten);
+}
+
+StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx) {
+ TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+ if (data_sinfo->IsUnknownNdim()) {
+ return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice);
+ }
+
+ if (data_sinfo->ndim < 2) {
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "batch_flatten expects input tensor to have at least 2
dimensions, "
+ << "but got " << data_sinfo->ndim);
+ }
+
+ if (data_sinfo->ndim == 2) {
+ return data_sinfo;
+ }
+
+ const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+ if (data_shape == nullptr) {
+ return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice);
+ }
+
+ PrimExpr batch_dim = data_shape->values[0];
+ PrimExpr flat_dim = IntImm(DataType::Int(64), 1);
+ for (size_t i = 1; i < data_shape->values.size(); ++i) {
+ flat_dim = flat_dim * data_shape->values[i];
+ }
+
+ return TensorStructInfo(ShapeExpr({batch_dim, flat_dim}), data_sinfo->dtype, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.nn.batch_flatten")
+ .set_num_inputs(1)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchFlatten)
+ .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+ .set_attr<Bool>("FPurity", Bool(true));
+
} // namespace relax
} // namespace tvm
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index 989dfbb3f6..b6f749854f 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -114,6 +114,9 @@ Expr cross_entropy_with_logits(Expr predictions, Expr labels);
Expr nll_loss(Expr predictions, Expr targets, ffi::Optional<Expr> weights, ffi::String reduction,
int ignore_index);
+/*! \brief Batch flatten: flatten all dimensions except the first (batch) dimension. */
+Expr batch_flatten(Expr data);
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_codegen_coreml.py b/tests/python/relax/test_codegen_coreml.py
index b07271e894..de3a6d0789 100644
--- a/tests/python/relax/test_codegen_coreml.py
+++ b/tests/python/relax/test_codegen_coreml.py
@@ -198,7 +198,6 @@ def test_relu():
verify(mod, [x_data])
[email protected]("`batch_flatten` is not implemented yet.")
def test_batch_flatten():
x = relax.Var("x", relax.TensorStructInfo([10, 10, 10], "float32"))
bb = relax.BlockBuilder()
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index b076827dc4..4c419ed0e1 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -1845,5 +1845,59 @@ def test_pixel_shuffle_infer_struct_info():
)
+def test_batch_flatten_op_correctness():
+ x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+ assert relax.op.nn.batch_flatten(x).op == Op.get("relax.nn.batch_flatten")
+
+
+def test_batch_flatten_infer_struct_info():
+ bb = relax.BlockBuilder()
+ vdev0 = VDevice("llvm")
+ x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+ x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+ x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+ x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+ x4 = relax.Var("x", R.Tensor((10, 20), "float32"))
+ x5 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
+
+ _check_inference(bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((2, 60), "float32"))
+ _check_inference(
+ bb, relax.op.nn.batch_flatten(x5), relax.TensorStructInfo((2, 60), "float32", vdev0)
+ )
+ _check_inference(
+ bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo(dtype="float32", ndim=2)
+ )
+ _check_inference(
+ bb, relax.op.nn.batch_flatten(x2), relax.TensorStructInfo(dtype="float32", ndim=2)
+ )
+ _check_inference(bb, relax.op.nn.batch_flatten(x3), relax.TensorStructInfo((2, 60), dtype=""))
+ _check_inference(bb, relax.op.nn.batch_flatten(x4), relax.TensorStructInfo((10, 20), "float32"))
+
+
+def test_batch_flatten_infer_struct_info_shape_symbolic():
+ bb = relax.BlockBuilder()
+ m = tir.Var("m", "int64")
+ n = tir.Var("n", "int64")
+ h = tir.Var("h", "int64")
+ w = tir.Var("w", "int64")
+ x0 = relax.Var("x", R.Tensor((m, n, h, w), "float32"))
+ x1 = relax.Var("x", R.Tensor((4, n, 8, 8), "float32"))
+
+ _check_inference(
+ bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((m, n * h * w), "float32")
+ )
+ _check_inference(
+ bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo((4, n * 8 * 8), "float32")
+ )
+
+
+def test_batch_flatten_infer_struct_info_wrong_ndim():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((3,), "float32"))
+
+ with pytest.raises(TVMError):
+ bb.normalize(relax.op.nn.batch_flatten(x0))
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 5d3390b810..3121c31ac8 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3898,5 +3898,48 @@ def test_pad():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_batch_flatten():
+ # fmt: off
+ @tvm.script.ir_module
+ class BatchFlatten:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 60),
"float32"):
+ gv: R.Tensor((2, 60), "float32") = R.nn.batch_flatten(x)
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 60), dtype="float32"):
+ gv = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((2, 60), dtype="float32"))
+ return gv
+
+ @T.prim_func(private=True)
+ def reshape(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(60)), "float32")):
+ T.func_attr({"tir.noalias": True})
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(60)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(x[(v_ax1 // T.int64(60) + v_ax0) % T.int64(2), v_ax1 % T.int64(60) // T.int64(20), v_ax1 % T.int64(20) // T.int64(5), v_ax1 % T.int64(5)])
+ T.writes(T_reshape[v_ax0, v_ax1])
+ T_reshape[v_ax0, v_ax1] = x[(v_ax1 // T.int64(60) + v_ax0) % T.int64(2), v_ax1 % T.int64(60) // T.int64(20), v_ax1 % T.int64(20) // T.int64(5), v_ax1 % T.int64(5)]
+ # fmt: on
+
+ mod = LegalizeOps()(BatchFlatten)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_batch_flatten_undefined_shape():
+ @tvm.script.ir_module
+ class BatchFlattenUndefinedShape:
+ @R.function
+ def main(x: R.Tensor(ndim=4, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"):
+ gv: R.Tensor(ndim=2, dtype="float32") = R.nn.batch_flatten(x)
+ return gv
+
+ mod = LegalizeOps()(BatchFlattenUndefinedShape)
+ tvm.ir.assert_structural_equal(mod, BatchFlattenUndefinedShape)
+
+
if __name__ == "__main__":
tvm.testing.main()
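
Postscript (not part of the commit): a small sketch of how the symbolic-shape inference added by `InferStructInfoBatchFlatten` can be checked from Python, under the same assumptions as the tests above.

```python
from tvm import relax, tir
from tvm.script import relax as R

# Symbolic batch/feature dims: batch_flatten keeps dim 0 and folds the rest.
m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = relax.Var("x", R.Tensor((m, n, 8, 8), "float32"))

bb = relax.BlockBuilder()
flattened = bb.normalize(relax.op.nn.batch_flatten(x))
# Expected: a 2-D float32 tensor with batch dim m and flattened dim n * 8 * 8.
print(flattened.struct_info)
```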