This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 4c5c086  [BYOC][bugfix] Handle empty tuples in annotation pass (#7288)
4c5c086 is described below

commit 4c5c086e2a259adeb486878c76c53896f3377fe8
Author: Steven S. Lyubomirsky <ss...@cs.washington.edu>
AuthorDate: Fri Jan 15 09:05:42 2021 -0500

    [BYOC][bugfix] Handle empty tuples in annotation pass (#7288)
---
 src/relay/transforms/annotate_target.cc         |  5 +++--
 tests/python/relay/test_pass_annotate_target.py | 26 +++++++++++++++++++++++--
 2 files changed, 27 insertions(+), 4 deletions(-)

diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc
index 76585cf..e365dca 100644
--- a/src/relay/transforms/annotate_target.cc
+++ b/src/relay/transforms/annotate_target.cc
@@ -144,11 +144,12 @@ class AnnotateTargetRewriter : public ExprRewriter {
      */
     Expr new_expr = expr;
     const CallNode* call = expr.as<CallNode>();
+    const TupleNode* tup = expr.as<TupleNode>();
     if (op_expr_to_target_.find(expr) != op_expr_to_target_.end()) {
       // Check whether expr has args, if not - do not insert compiler_end.
       if (expr->IsInstance<RefWriteNode>() || expr->IsInstance<RefCreateNode>() ||
-          expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleNode>() ||
-          expr->IsInstance<TupleGetItemNode>() || (call && !call->args.empty())) {
+          expr->IsInstance<RefReadNode>() || expr->IsInstance<TupleGetItemNode>() ||
+          (call && !call->args.empty()) || (tup && !tup->fields.empty())) {
         std::string target = op_expr_to_target_[new_expr];
         new_expr = InsertAnnotation(new_expr, target, make_end_op);
         op_expr_to_target_[new_expr] = target;
diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py
index 4f35066..ce86cc6 100644
--- a/tests/python/relay/test_pass_annotate_target.py
+++ b/tests/python/relay/test_pass_annotate_target.py
@@ -738,8 +738,8 @@ def test_if_free_vars():
         mod = tvm.IRModule.from_expr(func)
         return mod
 
-    for annotate_non_call_ops in [True, False, True]:
-        result = transform.AnnotateTarget(target)(before())
+    for annotate_non_call_ops in [True, False]:
+        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
         expected = transform.InferType()(after())
         assert tvm.ir.structural_equal(expected, result)
 
@@ -764,6 +764,27 @@ def test_free_vars_zeros():
     assert tvm.ir.structural_equal(expected, result)
 
 
+def test_empty_tuple():
+    target = "test_empty_tuple"
+
+    """An empty tuple should behave just like a call with no args (see above test)."""
+
+    def before():
+        func = relay.Function([], relay.Tuple([]))
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    def after():
+        func = relay.Function([], relay.Tuple([]))
+        mod = tvm.IRModule.from_expr(func)
+        return mod
+
+    for annotate_non_call_ops in [True, False]:
+        result = transform.AnnotateTarget(target, annotate_non_call_ops)(before())
+        expected = transform.InferType()(after())
+        assert tvm.ir.structural_equal(expected, result)
+
+
 if __name__ == "__main__":
     test_extern_dnnl()
     test_composite_function()
@@ -780,3 +801,4 @@ if __name__ == "__main__":
     test_double_target()
     test_ends_with_tuple()
     test_ref_create_read_write()
+    test_empty_tuple()