This is an automated email from the ASF dual-hosted git repository.

guanmingchiu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 71ac3e821b [Relax][NN] Add batch_flatten operator (#18677)
71ac3e821b is described below

commit 71ac3e821b750c29cbde0d676d0615777d13f95b
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Thu Jan 29 13:07:25 2026 +0800

    [Relax][NN] Add batch_flatten operator (#18677)
    
    ## Why
    The CoreML backend needs a `batch_flatten` operator, but Relax did not
    implement one, so the CoreML pattern and converter for it had been
    commented out.
    ## How
    - Add C++ operator `relax.nn.batch_flatten` with struct info inference
    - Add Python wrapper in `relax.op.nn`
    - Add legalization via `topi.reshape`
    - Enable CoreML pattern and converter for `nn.batch_flatten`
    - Add unit tests for the operator, struct info inference, and legalization
---
 python/tvm/relax/backend/metal/coreml.py           |  5 +-
 python/tvm/relax/op/nn/__init__.py                 |  1 +
 python/tvm/relax/op/nn/nn.py                       | 19 ++++++++
 python/tvm/relax/transform/legalize_ops/nn.py      |  7 +++
 src/relax/op/nn/nn.cc                              | 50 ++++++++++++++++++++
 src/relax/op/nn/nn.h                               |  3 ++
 tests/python/relax/test_codegen_coreml.py          |  1 -
 tests/python/relax/test_op_nn.py                   | 54 ++++++++++++++++++++++
 .../python/relax/test_transform_legalize_ops_nn.py | 43 +++++++++++++++++
 9 files changed, 179 insertions(+), 4 deletions(-)

diff --git a/python/tvm/relax/backend/metal/coreml.py b/python/tvm/relax/backend/metal/coreml.py
index dfc891dc1f..c7d922b599 100644
--- a/python/tvm/relax/backend/metal/coreml.py
+++ b/python/tvm/relax/backend/metal/coreml.py
@@ -142,11 +142,10 @@ register_patterns(
         *default_unary_patterns(op_name="nn.relu"),
         *default_unary_patterns(op_name="expand_dims"),
         *default_unary_patterns(op_name="nn.avg_pool2d"),
+        *default_unary_patterns(op_name="nn.batch_flatten"),
         *conv2d_patterns(),
         *clip_patterns(),
         *matmul_patterns(),
-        # TODO(@tvm-team): enable when relax op is implemented
-        # ("coreml.nn.batch_flatten", is_op("relax.nn.batch_flatten")(wildcard())),
     ]
 )
 
@@ -271,7 +270,7 @@ _convert_map = {
     "clip": _convert_clip,
     "expand_dims": _convert_expand_dims,
     "nn.relu": _convert_relu,
-    # "nn.batch_flatten": _convert_batch_flatten,
+    "nn.batch_flatten": _convert_batch_flatten,
     "nn.softmax": _convert_softmax,
     "nn.conv2d": _convert_conv2d,
     "nn.avg_pool2d": _convert_avg_pool2d,
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index ec1135ef2c..cf804440f1 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -25,6 +25,7 @@ from .nn import (
     avg_pool1d,
     avg_pool2d,
     avg_pool3d,
+    batch_flatten,
     batch_norm,
     conv1d,
     conv1d_transpose,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index c7be2a7ba6..e9710deca9 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -2249,3 +2249,22 @@ def attention_var_len(
         causal_mask,
         window_size,
     )  # type: ignore
+
+
+def batch_flatten(data: Expr) -> Expr:
+    """Flatten all dimensions except the first (batch) dimension.
+
+    This operation flattens a tensor of shape `(N, C, H, W, ...)` into
+    a 2D tensor of shape `(N, C*H*W*...)`.
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    Returns
+    -------
+    result : relax.Expr
+        The flattened result with shape `(batch_size, flattened_features)`.
+    """
+    return _ffi_api.batch_flatten(data)  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index ed9802fc9e..1a0477af20 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -775,3 +775,10 @@ def _nn_nll_loss(bb: BlockBuilder, call: Call) -> Expr:
         reduction=call.attrs.reduction,
         ignore_index=call.attrs.ignore_index,
     )
+
+
+@register_legalize("relax.nn.batch_flatten")
+def _nn_batch_flatten(bb: BlockBuilder, call: Call) -> Expr:
+    if call.struct_info.shape is None:
+        return call
+    return bb.call_te(topi.reshape, call.args[0], call.struct_info.shape.values)
diff --git a/src/relax/op/nn/nn.cc b/src/relax/op/nn/nn.cc
index f4b9fe400b..0a23358343 100644
--- a/src/relax/op/nn/nn.cc
+++ b/src/relax/op/nn/nn.cc
@@ -1192,5 +1192,55 @@ TVM_REGISTER_OP("relax.nn.nll_loss")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoNLLLoss)
     .set_attr<Bool>("FPurity", Bool(true));
 
+/* relax.nn.batch_flatten */
+
+Expr batch_flatten(Expr data) {
+  static const Op& op = Op::Get("relax.nn.batch_flatten");
+  return Call(op, {std::move(data)}, {}, {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.nn.batch_flatten", batch_flatten);
+}
+
+StructInfo InferStructInfoBatchFlatten(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+
+  if (data_sinfo->IsUnknownNdim()) {
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice);
+  }
+
+  if (data_sinfo->ndim < 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "batch_flatten expects input tensor to have at least 2 dimensions, "
+                     << "but got " << data_sinfo->ndim);
+  }
+
+  if (data_sinfo->ndim == 2) {
+    return data_sinfo;
+  }
+
+  const auto* data_shape = data_sinfo->shape.as<ShapeExprNode>();
+  if (data_shape == nullptr) {
+    return TensorStructInfo(data_sinfo->dtype, /*ndim=*/2, data_sinfo->vdevice);
+  }
+
+  PrimExpr batch_dim = data_shape->values[0];
+  PrimExpr flat_dim = IntImm(DataType::Int(64), 1);
+  for (size_t i = 1; i < data_shape->values.size(); ++i) {
+    flat_dim = flat_dim * data_shape->values[i];
+  }
+
+  return TensorStructInfo(ShapeExpr({batch_dim, flat_dim}), data_sinfo->dtype, data_sinfo->vdevice);
+}
+
+TVM_REGISTER_OP("relax.nn.batch_flatten")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoBatchFlatten)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kFollow)
+    .set_attr<Bool>("FPurity", Bool(true));
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/op/nn/nn.h b/src/relax/op/nn/nn.h
index 989dfbb3f6..b6f749854f 100644
--- a/src/relax/op/nn/nn.h
+++ b/src/relax/op/nn/nn.h
@@ -114,6 +114,9 @@ Expr cross_entropy_with_logits(Expr predictions, Expr labels);
 Expr nll_loss(Expr predictions, Expr targets, ffi::Optional<Expr> weights, ffi::String reduction,
               int ignore_index);
 
+/*! \brief Batch flatten: flatten all dimensions except the first (batch) dimension. */
+Expr batch_flatten(Expr data);
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/tests/python/relax/test_codegen_coreml.py b/tests/python/relax/test_codegen_coreml.py
index b07271e894..de3a6d0789 100644
--- a/tests/python/relax/test_codegen_coreml.py
+++ b/tests/python/relax/test_codegen_coreml.py
@@ -198,7 +198,6 @@ def test_relu():
     verify(mod, [x_data])
 
 
-@pytest.mark.skip("`batch_flatten` is not implemented yet.")
 def test_batch_flatten():
     x = relax.Var("x", relax.TensorStructInfo([10, 10, 10], "float32"))
     bb = relax.BlockBuilder()
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index b076827dc4..4c419ed0e1 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -1845,5 +1845,59 @@ def test_pixel_shuffle_infer_struct_info():
     )
 
 
+def test_batch_flatten_op_correctness():
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    assert relax.op.nn.batch_flatten(x).op == Op.get("relax.nn.batch_flatten")
+
+
+def test_batch_flatten_infer_struct_info():
+    bb = relax.BlockBuilder()
+    vdev0 = VDevice("llvm")
+    x0 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=-1))
+    x3 = relax.Var("x", R.Tensor((2, 3, 4, 5)))
+    x4 = relax.Var("x", R.Tensor((10, 20), "float32"))
+    x5 = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32", vdev0))
+
+    _check_inference(bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((2, 60), "float32"))
+    _check_inference(
+        bb, relax.op.nn.batch_flatten(x5), relax.TensorStructInfo((2, 60), "float32", vdev0)
+    )
+    _check_inference(
+        bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(
+        bb, relax.op.nn.batch_flatten(x2), relax.TensorStructInfo(dtype="float32", ndim=2)
+    )
+    _check_inference(bb, relax.op.nn.batch_flatten(x3), relax.TensorStructInfo((2, 60), dtype=""))
+    _check_inference(bb, relax.op.nn.batch_flatten(x4), relax.TensorStructInfo((10, 20), "float32"))
+
+
+def test_batch_flatten_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    h = tir.Var("h", "int64")
+    w = tir.Var("w", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n, h, w), "float32"))
+    x1 = relax.Var("x", R.Tensor((4, n, 8, 8), "float32"))
+
+    _check_inference(
+        bb, relax.op.nn.batch_flatten(x0), relax.TensorStructInfo((m, n * h * w), "float32")
+    )
+    _check_inference(
+        bb, relax.op.nn.batch_flatten(x1), relax.TensorStructInfo((4, n * 8 * 8), "float32")
+    )
+
+
+def test_batch_flatten_infer_struct_info_wrong_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3,), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.nn.batch_flatten(x0))
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 5d3390b810..3121c31ac8 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -3898,5 +3898,48 @@ def test_pad():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_batch_flatten():
+    # fmt: off
+    @tvm.script.ir_module
+    class BatchFlatten:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), "float32")) -> R.Tensor((2, 60), "float32"):
+            gv: R.Tensor((2, 60), "float32") = R.nn.batch_flatten(x)
+            return gv
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((2, 3, 4, 5), dtype="float32")) -> R.Tensor((2, 60), dtype="float32"):
+            gv = R.call_tir(Expected.reshape, (x,), out_sinfo=R.Tensor((2, 60), dtype="float32"))
+            return gv
+
+        @T.prim_func(private=True)
+        def reshape(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32"), T_reshape: T.Buffer((T.int64(2), T.int64(60)), "float32")):
+            T.func_attr({"tir.noalias": True})
+            for ax0, ax1 in T.grid(T.int64(2), T.int64(60)):
+                with T.block("T_reshape"):
+                    v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(x[(v_ax1 // T.int64(60) + v_ax0) % T.int64(2), v_ax1 % T.int64(60) // T.int64(20), v_ax1 % T.int64(20) // T.int64(5), v_ax1 % T.int64(5)])
+                    T.writes(T_reshape[v_ax0, v_ax1])
+                    T_reshape[v_ax0, v_ax1] = x[(v_ax1 // T.int64(60) + v_ax0) % T.int64(2), v_ax1 % T.int64(60) // T.int64(20), v_ax1 % T.int64(20) // T.int64(5), v_ax1 % T.int64(5)]
+    # fmt: on
+
+    mod = LegalizeOps()(BatchFlatten)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_batch_flatten_undefined_shape():
+    @tvm.script.ir_module
+    class BatchFlattenUndefinedShape:
+        @R.function
+        def main(x: R.Tensor(ndim=4, dtype="float32")) -> R.Tensor(ndim=2, dtype="float32"):
+            gv: R.Tensor(ndim=2, dtype="float32") = R.nn.batch_flatten(x)
+            return gv
+
+    mod = LegalizeOps()(BatchFlattenUndefinedShape)
+    tvm.ir.assert_structural_equal(mod, BatchFlattenUndefinedShape)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to