This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4fad6fc27f [Relax] Make ShapeType ndim parameter mandatory (#18814)
4fad6fc27f is described below

commit 4fad6fc27f8bcace487cd57660443f57f5bc78bf
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Tue Feb 24 21:32:40 2026 +0800

    [Relax] Make ShapeType ndim parameter mandatory (#18814)
    
    ## Why
    
    Forces callers to be explicit about whether they want unknown ndim (-1)
    or a specific value. Prevents accidentally creating an unknown-ndim
    ShapeType when the caller intended to pass a concrete dimension,
    catching bugs at the call site rather than silently propagating unknown
    shapes.
    
    ## How
    
    - Remove default ndim from C++ and Python ShapeType constructors
    - Update tests
    
    Signed-off-by: Guan-Ming Chiu <[email protected]>
---
 include/tvm/relax/type.h               | 3 +--
 python/tvm/relax/ty.py                 | 7 +++----
 tests/python/relax/test_ast_printer.py | 4 ++--
 tests/python/relax/test_struct_info.py | 4 ++--
 4 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/include/tvm/relax/type.h b/include/tvm/relax/type.h
index 8eaaf7bddc..32ec0f0f8f 100644
--- a/include/tvm/relax/type.h
+++ b/include/tvm/relax/type.h
@@ -53,8 +53,7 @@ class ShapeTypeNode : public TypeNode {
 
 class ShapeType : public Type {
  public:
-  // TODO(relax-team): remove the default value later.
-  TVM_DLL ShapeType(int ndim = kUnknownNDim, Span span = Span());
+  TVM_DLL ShapeType(int ndim, Span span = Span());
 
   TVM_FFI_DEFINE_OBJECT_REF_METHODS_NOTNULLABLE(ShapeType, Type, ShapeTypeNode);
 };
diff --git a/python/tvm/relax/ty.py b/python/tvm/relax/ty.py
index d7b619bf8d..afa25d0dd0 100644
--- a/python/tvm/relax/ty.py
+++ b/python/tvm/relax/ty.py
@@ -31,12 +31,11 @@ class ShapeType(Type):
 
     Parameters
     ----------
-    ndim : Optional[int]
-        The size of the shape.
+    ndim : int
+        The number of dimensions of the shape. Use -1 for unknown ndim.
     """
 
-    # TODO(relax-team): consider make ndim mandatory
-    def __init__(self, ndim: int = -1, span: Span = None) -> None:
+    def __init__(self, ndim: int, span: Span = None) -> None:
         self.__init_handle_by_constructor__(_ffi_api.ShapeType, ndim, span)  # type: ignore
 
 
diff --git a/tests/python/relax/test_ast_printer.py b/tests/python/relax/test_ast_printer.py
index d6e496c3f1..7c943e5d39 100644
--- a/tests/python/relax/test_ast_printer.py
+++ b/tests/python/relax/test_ast_printer.py
@@ -256,7 +256,7 @@ def test_shape_expr():
 
 def test_types():
     printer = ASTPrinter()
-    assert strip_whitespace(printer.visit_type_(rx.ShapeType())) == "ShapeType(ndim=-1)"
+    assert strip_whitespace(printer.visit_type_(rx.ShapeType(ndim=-1))) == "ShapeType(ndim=-1)"
     assert strip_whitespace(printer.visit_type_(rx.ShapeType(ndim=1))) == "ShapeType(ndim=1)"
     object_type = rx.ObjectType()
     assert strip_whitespace(printer.visit_type_(object_type)) == "ObjectType()"
@@ -266,7 +266,7 @@ def test_types():
     assert strip_whitespace(printer.visit_type_(tensor_type)) == "TensorType(ndim=2,dtype=int32)"
     unit_type = rx.TupleType([])
     assert strip_whitespace(printer.visit_type_(unit_type)) == "TupleType(fields=[])"
-    tuple_type = rx.TupleType([rx.ShapeType(), object_type])
+    tuple_type = rx.TupleType([rx.ShapeType(ndim=-1), object_type])
     assert_fields(
         "TupleType",
         {"fields": "[ShapeType(ndim=-1),ObjectType()]"},
diff --git a/tests/python/relax/test_struct_info.py b/tests/python/relax/test_struct_info.py
index 45599d198c..cf3202ddc9 100644
--- a/tests/python/relax/test_struct_info.py
+++ b/tests/python/relax/test_struct_info.py
@@ -52,8 +52,8 @@ def test_object_struct_info():
 
 
 def test_shape_type():
-    t0 = rx.ShapeType()
-    t1 = rx.ShapeType()
+    t0 = rx.ShapeType(ndim=-1)
+    t1 = rx.ShapeType(ndim=-1)
     assert t0 == t1
 
 

Reply via email to