This is an automated email from the ASF dual-hosted git repository. zhaowu pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new ce108c1 [Frontend] Add Span filling for frontends to Relay (#9723) ce108c1 is described below commit ce108c1f53235a483eb11dffffb8770642907642 Author: Chun-I Tsai <quic_chu...@quicinc.com> AuthorDate: Tue Dec 28 12:53:18 2021 +0800 [Frontend] Add Span filling for frontends to Relay (#9723) * [Frontend] Add Span filling for frontends to Relay * Add a common span filling feature for tf1/2, tflite and pytorch. * Add test case for Span filling in each frontend. * Expose Tuple and TupleGetItem to python end * [Frontend] Add Span filling for frontends to Relay * Fix lint errors * Change default string of scope_part in Pytorch * Reorder the span position for one to many conversion * [Frontend] Add Span filling for frontends to Relay * nit fixed * Add a bool flag to control print span * refactor pytorch get span to a briefer way * [Frontend] Add Span filling for frontends to Relay * Add one more condition for SpanFiller * Refine the format for those PyTorch nodes without scopeName * [Frontend] Add Span filling for frontends to Relay * Fix lint --- python/tvm/relay/expr.py | 7 ++- python/tvm/relay/frontend/common.py | 53 +++++++++++++++++++++ python/tvm/relay/frontend/pytorch.py | 19 ++++++++ python/tvm/relay/frontend/tensorflow.py | 17 +------ python/tvm/relay/frontend/tensorflow2.py | 17 +------ python/tvm/relay/frontend/tflite.py | 16 +++++-- src/printer/relay_text_printer.cc | 23 ++++++--- src/printer/text_printer.h | 2 +- src/relay/ir/expr.cc | 4 +- tests/python/frontend/pytorch/test_forward.py | 47 +++++++++++++++++++ tests/python/frontend/tensorflow/test_forward.py | 54 ++++++++++++++++++++++ .../frontend/tensorflow2/test_sequential_models.py | 24 +++++++++- tests/python/frontend/tflite/test_forward.py | 54 ++++++++++++++++++++++ 13 files changed, 289 insertions(+), 48 deletions(-) diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 811e205..598354e 100644 ---
a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -316,10 +316,13 @@ class TupleGetItem(ExprWithOp): index: int The index. + + span: Optional[tvm.relay.Span] + Span that points to original source code """ - def __init__(self, tuple_value, index): - self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index) + def __init__(self, tuple_value, index, span=None): + self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span) @tvm._ffi.register_object("relay.RefCreate") diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py index 407afc4..be3d5ae 100755 --- a/python/tvm/relay/frontend/common.py +++ b/python/tvm/relay/frontend/common.py @@ -25,6 +25,7 @@ from tvm.ir import IRModule from tvm.topi.utils import get_const_tuple from .. import expr as _expr +from ..expr_functor import ExprMutator from .. import function as _function from .. import transform as _transform from .. import op as _op @@ -954,3 +955,55 @@ def try_resolve_var_to_const(x, graph_params): return _op.const(value, dtype) return x + + +def set_span(sym, node_name): + """Set up the sapn of relay expression(s) while converting OP""" + + class SpanFiller(ExprMutator): + """SpanFiller""" + + def __init__(self, node_name, suffix_str="_PART_"): + ExprMutator.__init__(self) + self.node_name = node_name + self.suffix_str = suffix_str + self.counter = 0 + self.distance_from_leaf = -1 + + def _create_span(self): + if self.distance_from_leaf == 0: + return tvm.relay.Span(tvm.relay.SourceName(self.node_name), 0, 0, 0, 0) + self.distance_from_leaf -= 1 + span_str = "{}{}{}".format(self.node_name, self.suffix_str, str(self.counter)) + self.counter += 1 + return tvm.relay.Span(tvm.relay.SourceName(span_str), 0, 0, 0, 0) + + def visit_call(self, call): + if call.span is None: + self.distance_from_leaf += 1 + new_args = [self.visit(arg) for arg in call.args] + return _expr.Call( + call.op, new_args, call.attrs, call.type_args, 
self._create_span() + ) + return call + + def visit_tuple(self, tup): + if tup.span is None: + self.distance_from_leaf += 1 + return _expr.Tuple([self.visit(field) for field in tup.fields], self._create_span()) + return tup + + def visit_tuple_getitem(self, op): + if op.span is None: + self.distance_from_leaf += 1 + return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._create_span()) + return op + + def fill(self, sym): + if isinstance(sym, _expr.TupleWrapper): + return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size) + if isinstance(sym, _expr.RelayExpr): + return self.visit(sym) + return sym + + return SpanFiller(node_name).fill(sym) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 24ccad5..6e8ad68 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -45,6 +45,7 @@ from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value from .common import infer_value_simulated as _infer_value_simulated from .common import lstm_cell, try_infer_value, unbind +from .common import set_span from .pytorch_utils import is_version_greater_than __all__ = ["from_pytorch"] @@ -3271,6 +3272,9 @@ class PyTorchOpConverter: def convert_operators(self, operators, outputs, ret_names): """Convert each Torch IR operators to Relay equivalent""" + # an op node might not belong to any of scope in trace info natively + # use a cunter to prevent from messing up its scope in span + empty_counter = 0 for node_name, op_node in operators: operator = op_node.kind() inputs = _get_op_inputs(op_node, outputs) @@ -3308,6 +3312,9 @@ class PyTorchOpConverter: relay_out = relay_op( inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype) ) + span_str, empty_counter = self._get_torch_span(op_node, empty_counter) + relay_out = set_span(relay_out, span_str) + self.record_output_type(relay_out) if isinstance(relay_out, tuple): @@ -3321,6 +3328,18 
@@ class PyTorchOpConverter: return [_wrap_const(outputs[ret_name]) for ret_name in ret_names] + def _get_torch_span(self, node, empty_counter): + # torch span looks like + # %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu # ${torch}/nn file + # the scope part might not exist + if node.scopeName(): + scope_name_str = "jit._trace.TopLevelTracedModule: " + node.scopeName() + else: + scope_name_str = "warning: no trace info " + str(empty_counter) + empty_counter += 1 + span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str) + return span_str, empty_counter + def _pytorch_result_type(dtypes, non_tensor_inputs): """This promotes TVM dtypes like PyTorch would""" diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index d35e0e1..c2aa5a1 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -37,6 +37,7 @@ from .common import get_relay_op from .common import infer_type as _infer_type from .common import infer_shape as _infer_shape from .common import infer_value as _infer_value +from .common import set_span from .tensorflow_ops import _convert_map from .tensorflow_ops import _need_prelude_for_shape_inference @@ -1028,24 +1029,10 @@ class GraphProto(object): else: raise NotImplementedError("Operator {} not implemented.".format(op_name)) - sym = self._set_span(sym, node_name) + sym = set_span(sym, node_name) return sym - @staticmethod - def _set_span(sym, node_name): - span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) - if isinstance(sym, _expr.Call) and sym.span is None: - sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) - elif isinstance(sym, _expr.TupleWrapper): - tuple_value = sym.tuple_value - if isinstance(tuple_value, _expr.Call) and tuple_value.span is None: - tuple_value = _expr.Call( - tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span - ) - sym = _expr.TupleWrapper(tuple_value, 
sym.size) - return sym - def _licm_construct(self, loop_name, node_name): """Construct a node by considering whether it is loop invariant with the given while loop. If yes, we diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py index 465f530..2c8b7d4 100644 --- a/python/tvm/relay/frontend/tensorflow2.py +++ b/python/tvm/relay/frontend/tensorflow2.py @@ -36,6 +36,7 @@ from .. import analysis from .. import function as _function from ..loops import while_loop as _while_loop from .common import infer_type as _infer_type +from .common import set_span from .tensorflow_ops import _convert_map as _convert_map_common from .tensorflow_ops import _get_more_static_shape_rank @@ -58,22 +59,6 @@ def _infer_type_with_prelude(val, prelude): return body.checked_type -def set_span(sym, node_name): - """set span of symbol""" - - span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0) - if isinstance(sym, _expr.Call): - sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span) - elif isinstance(sym, _expr.TupleWrapper): - tuple_value = sym.tuple_value - if isinstance(tuple_value, _expr.Call): - tuple_value = _expr.Call( - tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span - ) - sym = _expr.TupleWrapper(tuple_value, sym.size) - return sym - - def is_tensor_list_constuctor(tf_node): """Check whether is tensor list constructor node.""" return tf_node.op == "TensorListReserve" diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py index f0f20e1..b0b2bd3 100644 --- a/python/tvm/relay/frontend/tflite.py +++ b/python/tvm/relay/frontend/tflite.py @@ -32,6 +32,7 @@ from .. import op as _op from .. 
import qnn as _qnn from .common import ExprTable from .common import infer_shape as _infer_shape +from .common import set_span from .common import to_int_list from .tflite_flexbuffer import FlexBufferDecoder @@ -239,12 +240,17 @@ class OperatorConverter(object): if len(output_tensors) == 1: tensor_idx = output_tensors[0].tensor_idx - self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret) + curr_output = get_tensor_name(self.subgraph, tensor_idx) + ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output)) + self.exp_tab.set_expr(curr_output, ret) else: - for idx, output_tensor in enumerate(output_tensors): - self.exp_tab.set_expr( - get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx] - ) + out_names = [] + for output_tensor in output_tensors: + out_names.append(get_tensor_name(self.subgraph, output_tensor.tensor_idx)) + curr_output = ", ".join(out_names) + ret = set_span(ret, "location: {}, output_name: {}".format(op_idx, curr_output)) + for idx, out_name in enumerate(out_names): + self.exp_tab.set_expr(out_name, ret[idx]) def get_op_code_str(self, op): """Get TFLite ops string representation""" diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc index fdc6c37..7654ef1 100644 --- a/src/printer/relay_text_printer.cc +++ b/src/printer/relay_text_printer.cc @@ -389,12 +389,21 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) { if (op->fields.size() == 1) { doc << ","; } - return doc << ")"; + doc << ")"; + if (op->span.defined()) { + doc << " /* " << PrintSpan(op->span) << " */"; + } + return doc; } Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) { Doc doc; - return doc << Print(op->tuple) << "." << op->index; + doc << Print(op->tuple) << "." 
<< op->index; + + if (op->span.defined()) { + doc << " /* " << PrintSpan(op->span) << " */"; + } + return doc; } Doc RelayTextPrinter::VisitExpr_(const IfNode* op) { @@ -968,11 +977,13 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& return doc; } -Doc RelayTextPrinter::PrintSpan(const Span& span) { +Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) { Doc doc; - const auto* span_node = span.as<SpanNode>(); - ICHECK(span_node); - doc << span_node->source_name->name; + if (include_spans) { + const auto* span_node = span.as<SpanNode>(); + ICHECK(span_node); + doc << span_node->source_name->name; + } return doc; } diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h index a4d0ff3..ca46700 100644 --- a/src/printer/text_printer.h +++ b/src/printer/text_printer.h @@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>, */ Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map); - Doc PrintSpan(const Span& span); + Doc PrintSpan(const Span& span, bool include_spans = true); Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false); diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index b680a49..f8cb4f0 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -362,8 +362,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple, TVM_REGISTER_NODE_TYPE(TupleGetItemNode); -TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) { - return TupleGetItem(tuple, index); +TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) { + return TupleGetItem(tuple, index, span); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 86970bf..a64fa0b 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ 
b/tests/python/frontend/pytorch/test_forward.py @@ -247,6 +247,53 @@ def verify_model( torch.cuda.empty_cache() +def verify_span(model_name, input_data=[], custom_convert_map={}): + if isinstance(model_name, str): + baseline_model, baseline_input = load_model(model_name) + elif isinstance(input_data, list): + baseline_model = model_name + baseline_input = input_data + elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0: + baseline_model = model_name + baseline_input = [input_data] + else: + assert False, "Unexpected input format" + + trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input]) + if isinstance(baseline_model, torch.nn.Module): + trace = trace.float().eval() + + if torch.cuda.is_available(): + trace = trace.cuda() + else: + trace = trace.cpu() + + input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)] + input_shapes = list(zip(input_names, [inp.shape for inp in baseline_input])) + mod, params = relay.frontend.from_pytorch(trace, input_shapes, custom_convert_map) + + # collect fail cases for the convenience of further improvement + fail_cases = [] + mod_main_start = False + for line in str(mod.__str__).split("\n"): + if "@main" in line: + mod_main_start = True + continue + + if mod_main_start == True: + if "}" == line: + break + elif not ("/*" in line and "*/" in line): + fail_cases.append(line) + + print(fail_cases) + assert len(fail_cases) == 0 + + +def test_span(): + verify_span("resnet18") + + # Single operator tests @tvm.testing.uses_gpu def test_forward_pixel_shuffle(): diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 338d219..be32ca3 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -298,6 +298,60 @@ def is_gpu_available(): return False +def verify_span(mod): + # collect fail cases for the convenience of further improvement + fail_cases = 
[] + mod_main_start = False + for line in str(mod.__str__).split("\n"): + if "@main" in line: + mod_main_start = True + continue + + if mod_main_start == True: + if "}" == line: + break + elif not ("/*" in line and "*/" in line): + fail_cases.append(line) + + print(fail_cases) + assert len(fail_cases) == 0 + + +def simple_model(): + input_node = tf.placeholder(shape=[None, None, 3, 1], dtype=np.float32, name="input") + + shape = tf.shape(input_node) + stack = tf.stack([shape[0], 3, 3], axis=0) + output_node = tf.reshape(input_node, stack, name="output") + return output_node + + +####################################################################### +# Span fill up +# ------- +def test_span_complement_simple_model(): + with tf.Graph().as_default() as graph: + model_graph = simple_model() + graph_def = graph.as_graph_def() + + graph_def = tf_testing.ProcessGraphDefParam(graph_def) + + mod, params = relay.frontend.from_tensorflow(graph_def, shape={"input:0", (1, 3, 3, 1)}) + verify_span(mod) + + +def test_span_complement_big_model(): + with tf.Graph().as_default() as graph: + graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb") + # Call the utility to import the graph definition into default graph. 
+ graph_def = tf_testing.ProcessGraphDefParam(graph_def) + + mod, params = relay.frontend.from_tensorflow( + graph_def, shape={"input_tensor:0", (128, 224, 224, 3)} + ) + verify_span(mod) + + ####################################################################### # Pooling # ------- diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py index 1b5a634..b76b4a7 100644 --- a/tests/python/frontend/tensorflow2/test_sequential_models.py +++ b/tests/python/frontend/tensorflow2/test_sequential_models.py @@ -26,6 +26,25 @@ from tensorflow.python.framework.convert_to_constants import convert_variables_t from common import compare_tf_tvm from common import run_tf_code +from tvm.relay.frontend.tensorflow2 import from_tensorflow + + +def verify_span(mod): + fail_cases = [] + mod_main_start = False + for line in str(mod.__str__).split("\n"): + if "@main" in line: + mod_main_start = True + continue + + if mod_main_start == True: + if "}" == line: + break + elif not ("/*" in line and "*/" in line): + fail_cases.append(line) + + print(fail_cases) + assert len(fail_cases) == 0 def run_sequential_model(model_fn, input_shape): @@ -48,7 +67,10 @@ def run_sequential_model(model_fn, input_shape): gdef = f.graph.as_graph_def(add_shapes=True) return gdef, _input, _output - compare_tf_tvm(*model_graph(model_fn, input_shape), runtime="vm") + gdef, _input, _output = model_graph(model_fn, input_shape) + mod, _ = from_tensorflow(gdef) + compare_tf_tvm(gdef, _input, _output, runtime="vm") + verify_span(mod) def test_dense_model(): diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py index 545315a..d234cd1 100644 --- a/tests/python/frontend/tflite/test_forward.py +++ b/tests/python/frontend/tflite/test_forward.py @@ -259,6 +259,59 @@ def run_tflite_graph(tflite_model_buf, input_data): return tflite_output +def run_span_verification( + tflite_model_buf, + input_data, 
+ input_node, + num_output=1, + target="llvm", + out_names=None, + mode="graph_executor", +): + """Generic function to compile on relay and execute on tvm""" + # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 + try: + import tflite.Model + + tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0) + except AttributeError: + import tflite + + tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0) + except ImportError: + raise ImportError("The tflite package must be installed") + + input_data = convert_to_list(input_data) + input_node = convert_to_list(input_node) + + shape_dict = {} + dtype_dict = {} + for i, e in enumerate(input_node): + shape_dict[e] = input_data[i].shape + dtype_dict[e] = input_data[i].dtype.name + + mod, _ = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict) + verify_span(mod) + + +def verify_span(mod): + fail_cases = [] + mod_main_start = False + for line in str(mod.__str__).split("\n"): + if "@main" in line: + mod_main_start = True + continue + + if mod_main_start == True: + if "}" == line: + break + elif not ("/*" in line and "*/" in line): + fail_cases.append(line) + + print(fail_cases) + assert len(fail_cases) == 0 + + def compare_tflite_with_tvm( in_data, in_name, @@ -4507,6 +4560,7 @@ def test_forward_tflite2_qnn_resnet50(): tflite_output = run_tflite_graph(tflite_model_buf, data) tflite_predictions = np.squeeze(tflite_output) tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1] + run_span_verification(tflite_model_buf, np.array(data), "input_1") tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1") tvm_predictions = np.squeeze(tvm_output) tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]