This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fed71ef6a6 [Relax] Add native size operator (#18667)
fed71ef6a6 is described below

commit fed71ef6a69facc6031144959f191cf70e963a67
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Jan 20 21:34:36 2026 +0800

    [Relax] Add native size operator (#18667)
    
    ## Why
    
    ONNX models use the Size operator to get the total element count of a
    tensor. Relax didn't have a native equivalent.
    
    ## How
    
    - Adds R.size(tensor) operator that returns the total number of elements
    in a tensor as a scalar int64
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py    |  3 +-
 python/tvm/relax/op/__init__.py                    |  1 +
 python/tvm/relax/op/base.py                        | 28 ++++++++--
 .../tvm/relax/transform/legalize_ops/inspect_op.py |  6 +++
 python/tvm/script/ir_builder/relax/ir.py           |  2 +
 src/relax/op/op.cc                                 | 26 +++++++++
 tests/python/relax/test_op_size.py                 | 63 ++++++++++++++++++++++
 7 files changed, 122 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index e14e2ed956..9968eb5ed8 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -911,8 +911,7 @@ class Size(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, bb, inputs, attr, params):
-        # TODO(tvm-team): add native support for size op
-        return relax.op.prod(relax.op.shape_to_tensor(relax.op.shape_of(inputs[0])))
+        return relax.op.size(inputs[0])
 
 
 class EyeLike(OnnxOpConverter):
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index c6504d79c9..2ebca3811f 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -40,6 +40,7 @@ from .base import (
     register_gradient,
     shape_of,
     shape_to_tensor,
+    size,
     tensor_to_shape,
     to_vdevice,
 )
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index ffa19fbaa0..d46aa883f0 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -634,6 +634,22 @@ def shape_of(expr: Expr) -> Expr:
     return _ffi_api.shape_of(expr)  # type: ignore # pylint: disable=no-member
 
 
+def size(expr: Expr) -> Expr:
+    """Get the total number of elements in a tensor.
+
+    Parameters
+    ----------
+    expr : Expr
+        The input tensor.
+
+    Returns
+    -------
+    result : Expr
+        A scalar tensor of dtype int64 containing the total number of elements.
+    """
+    return _ffi_api.size(expr)  # type: ignore # pylint: disable=no-member
+
+
 def tensor_to_shape(expr: Expr) -> Expr:
     """Convert tensor to shape expr.
     Parameters
@@ -777,11 +793,13 @@ def call_pure_packed(
         sinfo_args = [sinfo_args]
 
     sinfo_args = [
-        sinfo()
-        if callable(sinfo)
-        else sinfo.asobject()
-        if isinstance(sinfo, ObjectConvertible)
-        else sinfo
+        (
+            sinfo()
+            if callable(sinfo)
+            else sinfo.asobject()
+            if isinstance(sinfo, ObjectConvertible)
+            else sinfo
+        )
         for sinfo in sinfo_args
     ]
 
diff --git a/python/tvm/relax/transform/legalize_ops/inspect_op.py b/python/tvm/relax/transform/legalize_ops/inspect_op.py
index e031386e6e..a41c74cae0 100644
--- a/python/tvm/relax/transform/legalize_ops/inspect_op.py
+++ b/python/tvm/relax/transform/legalize_ops/inspect_op.py
@@ -23,6 +23,7 @@ from tvm.script import tir as T
 
 from ...block_builder import BlockBuilder
 from ...expr import Call, Expr
+from ... import op
 from .common import register_legalize
 
 
@@ -126,3 +127,8 @@ def _tensor_elem_offset(bb: BlockBuilder, call: Call) -> Expr:
 
     gvar = bb.add_func(_get_tensor_elem_offset, "_get_tensor_elem_offset")
     return Call(gvar, call.args)
+
+
+@register_legalize("relax.size")
+def _size(_bb: BlockBuilder, call: Call) -> Expr:
+    return op.prod(op.shape_to_tensor(op.shape_of(call.args[0])))
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index e0a009a94e..5410c3c03a 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -163,6 +163,7 @@ from tvm.relax.op import (
     sign,
     sin,
     sinh,
+    size,
     slice_scatter,
     sort,
     split,
@@ -938,6 +939,7 @@ __all__ = [
     "shape",
     "shape_of",
     "ShapeExpr",
+    "size",
     "std",
     "str",
     "sum",
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 3acfb53b27..d7d68766dd 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -1125,6 +1125,32 @@ TVM_FFI_STATIC_INIT_BLOCK() {
   refl::GlobalDef().def("relax.op.shape_of", MakeShapeOf);
 }
 
+// size
+
+StructInfo InferStructInfoSize(const Call& call, const BlockBuilder& ctx) {
+  auto arg_sinfo = GetStructInfo(call->args[0]);
+  auto* tensor_sinfo = GetStructInfo(call->args[0]).as<TensorStructInfoNode>();
+  CHECK(tensor_sinfo) << "size expects a tensor input, but received " << arg_sinfo
+                      << "; use MatchCast if necessary";
+  return TensorStructInfo(ShapeExpr(ffi::Array<PrimExpr>{}), DataType::Int(64));
+}
+
+TVM_REGISTER_OP("relax.size")
+    .set_num_inputs(1)
+    .add_argument("input", "Expr", "The input tensor")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSize)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+Expr MakeSize(Expr expr) {
+  static const Op& op = Op::Get("relax.size");
+  return Call(op, {expr}, {}, {});
+}
+
+TVM_FFI_STATIC_INIT_BLOCK() {
+  namespace refl = tvm::ffi::reflection;
+  refl::GlobalDef().def("relax.op.size", MakeSize);
+}
+
 // tensor_to_shape
 
 StructInfo ReturnTensorToShapeStructInfo(const Call& call, const BlockBuilder& ctx) {
diff --git a/tests/python/relax/test_op_size.py b/tests/python/relax/test_op_size.py
new file mode 100644
index 0000000000..77c5ebef5a
--- /dev/null
+++ b/tests/python/relax/test_op_size.py
@@ -0,0 +1,63 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+
+
+def test_op_size():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, 3), "float32")) -> R.Tensor((), "int64"):
+            return R.size(x)
+
+    x_np = np.random.rand(2, 3).astype("float32")
+    x = tvm.runtime.tensor(x_np)
+
+    target = tvm.target.Target("llvm")
+    ex = relax.build(Module, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    res = vm["main"](x)
+    assert res.numpy() == 6
+
+
+def test_op_size_dynamic():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor(("m", "n"), "float32")) -> R.Tensor((), "int64"):
+            return R.size(x)
+
+    x_np = np.random.rand(4, 5).astype("float32")
+    x = tvm.runtime.tensor(x_np)
+
+    target = tvm.target.Target("llvm")
+    ex = relax.build(Module, target)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+
+    res = vm["main"](x)
+    assert res.numpy() == 20
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to