masahi commented on a change in pull request #4944: [Relay, Torch] Clean up and refactor PyTorch frontend
URL: https://github.com/apache/incubator-tvm/pull/4944#discussion_r385433521
##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -718,293 +732,269 @@ def _convert_elemwise_input(data, input_type):
"aten::sqrt" : _sqrt()
}
-# Internal graph for parsing
-class Graph(object):
- """ A helper class for parsing PyTorch model to Relay graph."""
+def run_jit_passes(graph):
+ """ The inline pass is necessary to unwrap prim::CallMethod """
+ import torch
+ if version.parse(torch.__version__) >= version.parse("1.4.0"):
+ torch._C._jit_pass_inline(graph)
- def __init__(self, script_module, input_shapes):
- self._script_module = script_module
- self._graph = script_module.graph.copy()
+def is_int_seq(seq):
+ return len(seq) > 0 and all([isinstance(i, int) for i in seq])
- # TODO: Temporary fix to remove prim::CallMethod node introduced in PT
1.4
- import torch
- from packaging import version
- if version.parse(torch.__version__) >= version.parse("1.4.0"):
- torch._C._jit_pass_inline(self._graph)
-
- self._inputs_r = {}
- self._params = {}
- self._param_tensors = {}
- self._consts = {}
- self._ops = {}
- self._op_inputs_r = {}
- self._op_inputs_types = {}
- self._input_shapes = input_shapes if input_shapes else {}
- self._parsed_node_names = {}
-
- def from_pytorch(self):
- """ Construct relay nodes from PyTorch graph
-
- Currently only supports traced PyTorch format which means no control
flow.
- User must perform torch.jit.trace on a model and pass this in.
- Future support should include support scripted models
(torch.jit.script) which
- preserves control flow.
-
- Returns
- -------
- mod : tvm.relay.Module
- The module that optimizations will be performed on.
-
- params : dict of str to tvm.runtime
- Dict of converted parameters stored in tvm.runtime format
- """
- # Check for missing ops
- missing_operators = self._parse_import_prerequisites()
-
- if missing_operators:
- raise tvm.error.OpNotImplemented( \
- "The following operators are not implemented:
{}".format(missing_operators))
-
- # Translate PyTorch graph to by decorating Graph with state dict and
inputs into each op
- self._parse_inputs()
- self._parse_params()
- self._parse_ops()
-
- outputs = []
- nid = 0
-
- for op_name, op_node in self._ops.items():
- if op_node.kind() == "prim::ListConstruct":
- if any(inp.debugName() in self._parsed_node_names.keys() \
- for inp in op_node.inputs()):
- list_constr = []
- for i in op_node.inputs():
- if i.debugName() in self._parsed_node_names.keys():
- list_constr.append( \
-
outputs[self._parsed_node_names[i.debugName()]])
- elif i.node().kind() == "prim::Constant":
-
list_constr.append(int(self._consts[i.debugName()]))
- elif i.debugName() in self._inputs_r.keys():
-
list_constr.append(int(self._inputs_r[i.debugName()]))
-
- # Unwrap for tensors
- if len(list_constr) == 1:
- list_constr = list_constr[0]
-
- outputs.append(list_constr)
- self._parsed_node_names[op_name] = nid
- nid = nid+1
- elif op_node.kind() != "prim::Constant":
- for i in op_node.inputs():
- if i.debugName() in self._parsed_node_names.keys():
- for cnt in range(0, len(self._op_inputs_r[op_name])):
- if isinstance(self._op_inputs_r[op_name][cnt],
str):
- if "call/var" in
self._op_inputs_r[op_name][cnt]:
- self._op_inputs_r[op_name][cnt] = \
-
outputs[self._parsed_node_names[i.debugName()]]
- break
-
- call = _convert_map[op_node.kind()](self._op_inputs_r[op_name],
-
self._op_inputs_types[op_name])
-
- outputs.append(call)
- self._parsed_node_names[op_name] = nid
- nid = nid+1
-
- func = tvm.relay.Function(_analysis.free_vars(outputs[-1]),
outputs[-1])
-
- param = {k: tvm.nd.array(v) for k, v in self._param_tensors.items()}
-
- return _module.IRModule.from_expr(func), param
-
- def _parse_inputs(self):
- """ Map inputs to parser and inputs to graph. """
- # Get names and objects of inputs for IR
- ir_inputs = [i for i in self._graph.inputs()]
-
- # Create corresponding shape and add to input
- for input_name, ir_input in zip(self._input_shapes, ir_inputs[1:]):
- input_shape = self._input_shapes[input_name]
- ir_input.setDebugName(input_name)
-
- ir_dtype = _convert_data_type(ir_input.type().scalarType().lower())
- self._inputs_r[input_name] = _expr.var(input_name,
-
shape=self._input_shapes[input_name],
- dtype=ir_dtype)
-
- # Add self (first input of a PyTorch graph) to inputs, the value
doesn't matter here
- input_name = ir_inputs[0].debugName()
- self._inputs_r[input_name] = "self"
-
- def _parse_params(self):
- """ Map state dictionary values to corresponding prim::GetAttr op
node. """
- # Grab weights, biases, etc. from graph
- state_dict = self._script_module.state_dict()
- param_names = []
- for key, value in state_dict.items():
- param_str = str(key)
- param_name = param_str.split(".")[-1]
- param_names.append(param_name)
-
- # Get names of all inputs
- input_names = [i for i in self._inputs_r.keys()]
-
- # Iterate through graph for getAttr nodes and match full state_dict
name to nodes
- node_weight_map = {}
- for node in self._graph.nodes():
- if node.kind() == "prim::GetAttr":
-
- attribute_names = node.attributeNames()
- assert len(attribute_names) == 1
- node_getattr_name = node.s(attribute_names[0])
- node_arg = node.input().debugName()
-
- if node.outputsSize() == 1:
- node_name = node.output().debugName()
- else:
- node_name = [output.debugName() for output in
node.outputs()][0]
-
- if node_arg in input_names:
- node_weight_map[node_name] = node_getattr_name
- else:
- previous_map = node_weight_map[node_arg[:]]
- node_weight_map[node_name] =
previous_map+"."+node_getattr_name
-
- if node_getattr_name in param_names:
-
- value = state_dict[node_weight_map[node_name]]
- tensor = tvm.nd.array(value.cpu().numpy())
- shape = tensor.shape
- self._param_tensors[node_name] = tensor
-
- self._params[node_name] = _expr.var(node_name,
- shape=shape,
-
dtype=_convert_data_type(str(value.dtype)))
-
- def _parse_ops(self):
- """ Iterate through nodes and decorate graph with constants, operators,
- and the inputs to each operator. """
- # Traverse nodes and add to graph
- for node in self._graph.nodes():
-
- if node.outputsSize() == 1:
- node_name = node.output().debugName()
- else:
- node_name = [output.debugName() for output in
node.outputs()][0]
-
- if node.kind() == "prim::Constant":
- if node.hasAttributes():
- attribute_names = node.attributeNames()
- attr_name = attribute_names[0]
- ty = node.output().type().kind()
-
- if ty in ["IntType", "BoolType"]:
- self._consts[node_name] = node.i(attr_name)
- elif ty in ["FloatType", "LongType"]:
- self._consts[node_name] = node.f(attr_name)
- elif ty in ["TensorType", "CompleteTensorType"]:
- self._consts[node_name] = node.output().toIValue()
- else:
- self._consts[node_name] = "0"
- else:
- self._consts[node_name] = "0"
- elif node.kind() == "prim::ListConstruct":
- list_shape = []
- for input_node in node.inputs():
- if input_node.debugName() in self._inputs_r.keys():
- c = self._inputs_r[input_node.debugName()]
- assert isinstance(c, int)
- list_shape.append(c)
- elif input_node.debugName() in self._consts.keys():
- c = self._consts[input_node.debugName()]
- assert isinstance(c, int)
- list_shape.append(c)
- self._inputs_r[node_name] = _expr.var(node_name,
shape=list_shape)
-
- if node.kind() != "prim::GetAttr":
- self._add_op(node_name, node)
-
- # Graph Helper Functions
-
- def _add_op(self, node_id, op_node):
- """ Add an operator and its operators inputs to the graph and insert
placeholders
- where an input is a call node.
-
- Parameters
- ----------
- node_id : string
- The ID of the op node
-
- op_node : PyTorch Node object
- The full Node object for the op node
-
- """
- self._ops[(node_id)] = op_node
- input_list_r = []
- input_list_types = []
- for input_value in op_node.inputs():
-
- inode_id = input_value.debugName()
- inode = input_value.node()
-
- if inode_id in self._inputs_r.keys():
- input_list_r.append(self._inputs_r[inode_id])
- elif inode_id in self._params.keys():
- input_list_r.append(self._params[inode_id])
- elif inode.kind() == "prim::Constant":
- input_list_r.append(self._consts[inode_id])
+
+def get_tensor_and_var(torch_tensor, name):
Review comment:
Yeah, personally I find that convention ugly, but I'd blame Python for this :)
Will add a leading '_' to the names of non-public helper functions where appropriate.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services