This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4fad6fc27f [Relax] Make ShapeType ndim parameter mandatory (#18814)
4fad6fc27f is described below
commit 4fad6fc27f8bcace487cd57660443f57f5bc78bf
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Feb 24 21:32:40 2026 +0800
[Relax] Make ShapeType ndim parameter mandatory (#18814)
## Why
Forces callers to be explicit about whether they want unknown ndim (-1)
or a specific value. Prevents accidentally creating an unknown-ndim
ShapeType when the caller intended to pass a concrete dimension,
catching bugs at the call site rather than silently propagating unknown
shapes.
## How
- Remove default ndim from C++ and Python ShapeType constructors
- Update tests
Signed-off-by: Guan-Ming Chiu <[email protected]>
---
include/tvm/relax/type.h | 3 +--
python/tvm/relax/ty.py | 7 +++----
tests/python/relax/test_ast_printer.py | 4 ++--
tests/python/relax/test_struct_info.py | 4 ++--
4 files changed, 8 insertions(+), 10 deletions(-)
diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h
index 8eaaf7bddc..32ec0f0f8f 100644
--- a/include/tvm/relax/type.h
+++ b/include/tvm/relax/type.h
@@ -53,8 +53,7 @@ class ShapeTypeNode : public TypeNode {
class ShapeType : public Type {
public:
- // TODO(relax-team): remove the default value later.
- TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span());
+ TVM_DLL ShapeType(int ndim, Span span = Span());
TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ShapeType, Type,
ShapeTypeNode);
};
diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py
index d7b619bf8d..afa25d0dd0 100644
--- a/python/tvm/relax/ty.py
+++ b/python/tvm/relax/ty.py
@@ -31,12 +31,11 @@ class ShapeType(Type):
Parameters
----------
- ndim : Optional[int]
- The size of the shape.
+ ndim : int
+ The number of dimensions of the shape. Use -1 for unknown ndim.
"""
- # TODO(relax-team): consider make ndim mandatory
- def __init__(self, ndim: int = -1, span: Span = None) -> None:
+ def __init__(self, ndim: int, span: Span = None) -> None:
self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span)  # type: ignore
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index d6e496c3f1..7c943e5d39 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -256,7 +256,7 @@ def test_shape_expr():
def test_types():
printer = ASTPrinter()
- assert strip_whitespace(printer.visit_type_(rx.ShapeType())) == "ShapeType(ndim=-1)"
+ assert strip_whitespace(printer.visit_type_(rx.ShapeType(ndim=-1))) == "ShapeType(ndim=-1)"
assert strip_whitespace(printer.visit_type_(rx.ShapeType(ndim=1))) == "ShapeType(ndim=1)"
object_type = rx.ObjectType()
assert strip_whitespace(printer.visit_type_(object_type)) == "ObjectType()"
@@ -266,7 +266,7 @@ def test_types():
assert strip_whitespace(printer.visit_type_(tensor_type)) == "TensorType(ndim=2,dtype=int32)"
unit_type = rx.TupleType([])
assert strip_whitespace(printer.visit_type_(unit_type)) == "TupleType(fields=[])"
- tuple_type = rx.TupleType([rx.ShapeType(), object_type])
+ tuple_type = rx.TupleType([rx.ShapeType(ndim=-1), object_type])
assert_fields(
"TupleType",
{"fields": "[ShapeType(ndim=-1),ObjectType()]"},
diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py
index 45599d198c..cf3202ddc9 100644
--- a/tests/python/relax/test_struct_info.py
+++ b/tests/python/relax/test_struct_info.py
@@ -52,8 +52,8 @@ def test_object_struct_info():
def test_shape_type():
- t0 = rx.ShapeType()
- t1 = rx.ShapeType()
+ t0 = rx.ShapeType(ndim=-1)
+ t1 = rx.ShapeType(ndim=-1)
assert t0 == t1