This is an automated email from the ASF dual-hosted git repository.

moreau pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 4b2c01a  [Parser] Add support for parsing the any dimension.  (#6277)
4b2c01a is described below

commit 4b2c01a8fcba1f5941ccd18d2b1940fe8cefa7f1
Author: Jared Roesch <[email protected]>
AuthorDate: Fri Aug 14 09:13:42 2020 -0700

    [Parser] Add support for parsing the any dimension.  (#6277)
    
    * Add case for any dimensions
    
    * Fix second test case
---
 src/parser/parser.cc                 |  5 +++--
 tests/python/relay/test_ir_parser.py | 28 ++++++++++++++++++++++++++++
 2 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/src/parser/parser.cc b/src/parser/parser.cc
index 71d4304..8055d91 100644
--- a/src/parser/parser.cc
+++ b/src/parser/parser.cc
@@ -1502,6 +1502,8 @@ class Parser {
           tvm::PrimExpr dim;
           if (Peek()->token_type == TokenType::kMetaReference) {
             dim = Downcast<tvm::PrimExpr>(ParseMetaRef());
+          } else if (WhenMatch(TokenType::kQuestion)) {
+            dim = tvm::tir::Any();
           } else {
             dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data);
           }
@@ -1585,8 +1587,7 @@ class Parser {
           return ParseNonPrimitiveType(tok);
         }
       }
-    }
-    if (WhenMatch(TokenType::kUnderscore)) {
+    } else if (WhenMatch(TokenType::kUnderscore)) {
       return IncompleteType();
     } else {
       this->diag_ctx->EmitFatal(Diagnostic::Error(tok->span)
diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py
index 3fcc7da..6d581b6 100644
--- a/tests/python/relay/test_ir_parser.py
+++ b/tests/python/relay/test_ir_parser.py
@@ -591,6 +591,16 @@ def test_tensor_type():
         )
     )
 
+    assert_parses_as(
+        "let %_ : Tensor[(?, 1), float32] = (); ()",
+        relay.Let(
+            relay.Var("_", relay.TensorType((tvm.tir.Any(), 1), "float32")),
+            UNIT,
+            UNIT
+        )
+    )
+
+
 
 def test_function_type():
     assert_parses_as(
@@ -678,6 +688,24 @@ def test_adt_defn():
         mod
     )
 
+def test_adt_any():
+    code = """
+    type my_dtype {
+        my_cons(Tensor[(?, 1), uint16]),
+    }
+    """
+    mod = parse_module(code)
+    items = mod.type_definitions.items()
+    global_type_var, type_data = items[0]
+    assert global_type_var.name_hint == "my_dtype"
+    ctors = type_data.constructors
+    assert len(ctors) == 1
+    my_cons = ctors[0]
+    assert my_cons.name_hint == "my_cons"
+    ty_shape = my_cons.inputs[0].shape
+    assert isinstance(ty_shape[0], tvm.tir.Any)
+    assert ty_shape[1] == 1
+
 
 def test_empty_adt_defn():
     mod = tvm.IRModule()

Reply via email to