masahi commented on a change in pull request #4964: [Torch] Add initial control flow support
URL: https://github.com/apache/incubator-tvm/pull/4964#discussion_r386067505
 
 

 ##########
 File path: python/tvm/relay/frontend/pytorch.py
 ##########
 @@ -955,7 +1025,100 @@ def parse_params(graph, state_dict):
     return params, param_tensors
 
 
-def parse_operators(operators, outputs, output_index_map, ret_name):
+def convert_block(block, outputs, output_index_map):
+    """ Translate Torch "Block", used for prim::If and prim::Loop """
+    ops = _get_operator_nodes(block.nodes())
+    ret_names = _get_input_names(block.returnNode())
+    return convert_operators(ops, outputs, output_index_map, ret_names)
+
+
+def convert_if(if_node, outputs, output_index_map):
+    """ Translate Torch prim::If to Relay If """
+    cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
+    blocks = list(if_node.blocks())
+    true_branch = convert_block(blocks[0], outputs, output_index_map)
+    false_branch = convert_block(blocks[1], outputs, output_index_map)
+    assert len(true_branch) == 1 and len(false_branch) == 1
+    return _expr.If(cond, true_branch[0], false_branch[0])
+
+
+def convert_loop(loop_node, outputs, output_index_map):
+    """ Translate Torch prim::Loop to Relay while_loop """
+    def get_input(index):
+        ivalue = loop_node.inputsAt(index)
+        inode = ivalue.node()
+        if inode.kind() == "prim::Constant":
+            return _expr.const(_get_constant(inode))
+        var_name = ivalue.debugName()
+        assert var_name in output_index_map
+        return _wrap_const(outputs[output_index_map[var_name]])
+
+    # Refer to the spec for prim::Loop below
+    # https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
+    # The first input: %max_trip_count
+    # The second input: %initial_condition
+    # The rest of input: loop variables
+    max_loop_count = get_input(0)
 
 Review comment:
   Yes, in my test case `max_loop_count` is the output of `aten::size(...)`:
   ```
     ...
     %5 : int = aten::size(%inp.1, %a.1) # test_forward.py:798:27
     %a : int = prim::Loop(%5, %2, %a.1) # test_forward.py:798:12
    ...
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to