zhiics commented on a change in pull request #4964: [Torch] Add initial control flow support URL: https://github.com/apache/incubator-tvm/pull/4964#discussion_r386065362
########## File path: python/tvm/relay/frontend/pytorch.py ########## @@ -955,7 +1025,100 @@ def parse_params(graph, state_dict): return params, param_tensors -def parse_operators(operators, outputs, output_index_map, ret_name): +def convert_block(block, outputs, output_index_map): + """ Translate Torch "Block", used for prim::If and prim::Loop """ + ops = _get_operator_nodes(block.nodes()) + ret_names = _get_input_names(block.returnNode()) + return convert_operators(ops, outputs, output_index_map, ret_names) + + +def convert_if(if_node, outputs, output_index_map): + """ Translate Torch prim::If to Relay If """ + cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]] + blocks = list(if_node.blocks()) + true_branch = convert_block(blocks[0], outputs, output_index_map) + false_branch = convert_block(blocks[1], outputs, output_index_map) + assert len(true_branch) == 1 and len(false_branch) == 1 + return _expr.If(cond, true_branch[0], false_branch[0]) + + +def convert_loop(loop_node, outputs, output_index_map): + """ Translate Torch prim::Loop to Relay while_loop """ + def get_input(index): + ivalue = loop_node.inputsAt(index) + inode = ivalue.node() + if inode.kind() == "prim::Constant": + return _expr.const(_get_constant(inode)) + var_name = ivalue.debugName() + assert var_name in output_index_map + return _wrap_const(outputs[output_index_map[var_name]]) + + # Refer to the spec for prim::Loop below + # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops + # The first input: %max_trip_count + # The second input: %initial_condition + # The rest of input: loop variables + max_loop_count = get_input(0) + init_cond = get_input(1) + num_loop_var = len(list(loop_node.inputs())) - 2 + init_vals = [get_input(i + 2) for i in range(num_loop_var)] + + # For loop (not while loop) has always %initial_condition being 1 + is_for_loop = isinstance(init_cond, _expr.Constant) + + if is_for_loop: + loop_iter_dtype = "int32" + # always count from 0 
+ init_loop_iter_val = _expr.const(0, dtype="int32") + else: + loop_iter_dtype = "bool" + init_loop_iter_val = init_cond + + body_block = list(loop_node.blocks())[0] + inames = _get_input_names(body_block) + loop_input_vals = [init_loop_iter_val] + init_vals + name_val_pairs = list(zip(inames, loop_input_vals)) + _update_outputs_from_pairs(name_val_pairs, outputs, output_index_map) + + def cond(*current_vals): + i = current_vals[0] + + if is_for_loop: + return _op.less(i, max_loop_count) + + return _op.equal(i, _expr.const(True, 'bool')) + + def body(*current_vals): + # Update loop variables using the prev iteration outputs Review comment: let's assert that `inames` and `current_vals` correspond one-to-one (e.g. `assert len(inames) == len(current_vals)`) ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services