This is an automated email from the ASF dual-hosted git repository.
moreau pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 4b2c01a [Parser] Add support for parsing the any dimension. (#6277)
4b2c01a is described below
commit 4b2c01a8fcba1f5941ccd18d2b1940fe8cefa7f1
Author: Jared Roesch <[email protected]>
AuthorDate: Fri Aug 14 09:13:42 2020 -0700
[Parser] Add support for parsing the any dimension. (#6277)
* Add case for any dimensions
* Fix second test case
---
src/parser/parser.cc | 5 +++--
tests/python/relay/test_ir_parser.py | 28 ++++++++++++++++++++++++++++
2 files changed, 31 insertions(+), 2 deletions(-)
diff --git a/src/parser/parser.cc b/src/parser/parser.cc
index 71d4304..8055d91 100644
--- a/src/parser/parser.cc
+++ b/src/parser/parser.cc
@@ -1502,6 +1502,8 @@ class Parser {
tvm::PrimExpr dim;
if (Peek()->token_type == TokenType::kMetaReference) {
dim = Downcast<tvm::PrimExpr>(ParseMetaRef());
+ } else if (WhenMatch(TokenType::kQuestion)) {
+ dim = tvm::tir::Any();
} else {
dim = Downcast<tvm::PrimExpr>(Match(TokenType::kInteger)->data);
}
@@ -1585,8 +1587,7 @@ class Parser {
return ParseNonPrimitiveType(tok);
}
}
- }
- if (WhenMatch(TokenType::kUnderscore)) {
+ } else if (WhenMatch(TokenType::kUnderscore)) {
return IncompleteType();
} else {
this->diag_ctx->EmitFatal(Diagnostic::Error(tok->span)
diff --git a/tests/python/relay/test_ir_parser.py b/tests/python/relay/test_ir_parser.py
index 3fcc7da..6d581b6 100644
--- a/tests/python/relay/test_ir_parser.py
+++ b/tests/python/relay/test_ir_parser.py
@@ -591,6 +591,16 @@ def test_tensor_type():
)
)
+ assert_parses_as(
+ "let %_ : Tensor[(?, 1), float32] = (); ()",
+ relay.Let(
+ relay.Var("_", relay.TensorType((tvm.tir.Any(), 1), "float32")),
+ UNIT,
+ UNIT
+ )
+ )
+
+
def test_function_type():
assert_parses_as(
@@ -678,6 +688,24 @@ def test_adt_defn():
mod
)
+def test_adt_any():
+ code = """
+ type my_dtype {
+ my_cons(Tensor[(?, 1), uint16]),
+ }
+ """
+ mod = parse_module(code)
+ items = mod.type_definitions.items()
+ global_type_var, type_data = items[0]
+ assert global_type_var.name_hint == "my_dtype"
+ ctors = type_data.constructors
+ assert len(ctors) == 1
+ my_cons = ctors[0]
+ assert my_cons.name_hint == "my_cons"
+ ty_shape = my_cons.inputs[0].shape
+ assert isinstance(ty_shape[0], tvm.tir.Any)
+ assert ty_shape[1] == 1
+
def test_empty_adt_defn():
mod = tvm.IRModule()