This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ce108c1  [Frontend] Add Span filling for frontends to Relay (#9723)
ce108c1 is described below

commit ce108c1f53235a483eb11dffffb8770642907642
Author: Chun-I Tsai <quic_chu...@quicinc.com>
AuthorDate: Tue Dec 28 12:53:18 2021 +0800

    [Frontend] Add Span filling for frontends to Relay (#9723)
    
    * [Frontend] Add Span filling for frontends to Relay
    
    * Add a common span-filling feature for TF1/TF2, TFLite, and PyTorch.
    * Add test cases for span filling in each frontend.
    * Expose Tuple and TupleGetItem to the Python side
    
    * [Frontend] Add Span filling for frontends to Relay
    
    * Fix lint errors
    * Change the default string of scope_part in PyTorch
    * Reorder the span position for one-to-many conversions
    
    * [Frontend] Add Span filling for frontends to Relay
    
    * Nit fixes
    * Add a bool flag to control span printing
    * Refactor the PyTorch get-span logic into a briefer form
    
    * [Frontend] Add Span filling for frontends to Relay
    
    * Add one more condition for SpanFiller
    * Refine the format for PyTorch nodes without a scopeName
    
    * [Frontend] Add Span filling for frontends to Relay
    
    * Fix lint
---
 python/tvm/relay/expr.py                           |  7 ++-
 python/tvm/relay/frontend/common.py                | 53 +++++++++++++++++++++
 python/tvm/relay/frontend/pytorch.py               | 19 ++++++++
 python/tvm/relay/frontend/tensorflow.py            | 17 +------
 python/tvm/relay/frontend/tensorflow2.py           | 17 +------
 python/tvm/relay/frontend/tflite.py                | 16 +++++--
 src/printer/relay_text_printer.cc                  | 23 ++++++---
 src/printer/text_printer.h                         |  2 +-
 src/relay/ir/expr.cc                               |  4 +-
 tests/python/frontend/pytorch/test_forward.py      | 47 +++++++++++++++++++
 tests/python/frontend/tensorflow/test_forward.py   | 54 ++++++++++++++++++++++
 .../frontend/tensorflow2/test_sequential_models.py | 24 +++++++++-
 tests/python/frontend/tflite/test_forward.py       | 54 ++++++++++++++++++++++
 13 files changed, 289 insertions(+), 48 deletions(-)

diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 811e205..598354e 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -316,10 +316,13 @@ class TupleGetItem(ExprWithOp):
 
     index: int
         The index.
+
+    span: Optional[tvm.relay.Span]
+        Span that points to the original source code
     """
 
-    def __init__(self, tuple_value, index):
-        self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index)
+    def __init__(self, tuple_value, index, span=None):
+        self.__init_handle_by_constructor__(_ffi_api.TupleGetItem, tuple_value, index, span)
 
 
 @tvm._ffi.register_object("relay.RefCreate")
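
With the constructor change above, frontends can attach a span when building a
TupleGetItem from Python. A minimal sketch of the new optional argument (the
node name below is illustrative, not from the patch):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1, 3))
    y = relay.var("y", shape=(1, 3))
    tup = relay.Tuple([x, y])
    # Tag the tuple access with a span naming a hypothetical frontend node.
    span = relay.Span(relay.SourceName("my_frontend_node"), 0, 0, 0, 0)
    item = relay.TupleGetItem(tup, 0, span)  # span is optional, defaults to None
    print(item)
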
diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 407afc4..be3d5ae 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -25,6 +25,7 @@ from tvm.ir import IRModule
 from tvm.topi.utils import get_const_tuple
 
 from .. import expr as _expr
+from ..expr_functor import ExprMutator
 from .. import function as _function
 from .. import transform as _transform
 from .. import op as _op
@@ -954,3 +955,55 @@ def try_resolve_var_to_const(x, graph_params):
         return _op.const(value, dtype)
 
     return x
+
+
+def set_span(sym, node_name):
+    """Set up the sapn of relay expression(s) while converting OP"""
+
+    class SpanFiller(ExprMutator):
+        """SpanFiller"""
+
+        def __init__(self, node_name, suffix_str="_PART_"):
+            ExprMutator.__init__(self)
+            self.node_name = node_name
+            self.suffix_str = suffix_str
+            self.counter = 0
+            self.distance_from_leaf = -1
+
+        def _create_span(self):
+            if self.distance_from_leaf == 0:
+                return tvm.relay.Span(tvm.relay.SourceName(self.node_name), 0, 0, 0, 0)
+            self.distance_from_leaf -= 1
+            span_str = "{}{}{}".format(self.node_name, self.suffix_str, 
str(self.counter))
+            self.counter += 1
+            return tvm.relay.Span(tvm.relay.SourceName(span_str), 0, 0, 0, 0)
+
+        def visit_call(self, call):
+            if call.span is None:
+                self.distance_from_leaf += 1
+                new_args = [self.visit(arg) for arg in call.args]
+                return _expr.Call(
+                    call.op, new_args, call.attrs, call.type_args, self._create_span()
+                )
+            return call
+
+        def visit_tuple(self, tup):
+            if tup.span is None:
+                self.distance_from_leaf += 1
+                return _expr.Tuple([self.visit(field) for field in tup.fields], self._create_span())
+            return tup
+
+        def visit_tuple_getitem(self, op):
+            if op.span is None:
+                self.distance_from_leaf += 1
+                return _expr.TupleGetItem(self.visit(op.tuple_value), op.index, self._create_span())
+            return op
+
+        def fill(self, sym):
+            if isinstance(sym, _expr.TupleWrapper):
+                return _expr.TupleWrapper(self.visit(sym.tuple_value), sym.size)
+            if isinstance(sym, _expr.RelayExpr):
+                return self.visit(sym)
+            return sym
+
+    return SpanFiller(node_name).fill(sym)
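
SpanFiller rewrites every span-less call, tuple, and tuple-getitem produced
for a single frontend node: the outermost expression receives the plain node
name, while inner expressions from one-to-many conversions get a _PART_<n>
suffix. A minimal usage sketch (the node name is illustrative):

    import tvm
    from tvm import relay
    from tvm.relay.frontend.common import set_span

    # One frontend op lowered to two Relay calls: the outer relu is named
    # "my_node", the inner softmax becomes "my_node_PART_0".
    x = relay.var("x", shape=(1, 8))
    out = set_span(relay.nn.relu(relay.nn.softmax(x)), "my_node")
    print(out)
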
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 24ccad5..6e8ad68 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -45,6 +45,7 @@ from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
 from .common import lstm_cell, try_infer_value, unbind
+from .common import set_span
 from .pytorch_utils import is_version_greater_than
 
 __all__ = ["from_pytorch"]
@@ -3271,6 +3272,9 @@ class PyTorchOpConverter:
 
     def convert_operators(self, operators, outputs, ret_names):
         """Convert each Torch IR operators to Relay equivalent"""
+        # an op node might not natively belong to any scope in the trace info
+        # use a counter to disambiguate the span names of such nodes
+        empty_counter = 0
         for node_name, op_node in operators:
             operator = op_node.kind()
             inputs = _get_op_inputs(op_node, outputs)
@@ -3308,6 +3312,9 @@ class PyTorchOpConverter:
                 relay_out = relay_op(
                     inputs, _get_input_types(op_node, outputs, default_dtype=self.default_dtype)
                 )
+                span_str, empty_counter = self._get_torch_span(op_node, empty_counter)
+                relay_out = set_span(relay_out, span_str)
+
                 self.record_output_type(relay_out)
 
                 if isinstance(relay_out, tuple):
@@ -3321,6 +3328,18 @@ class PyTorchOpConverter:
 
         return [_wrap_const(outputs[ret_name]) for ret_name in ret_names]
 
+    def _get_torch_span(self, node, empty_counter):
+        # torch span looks like
+        # %input.5 : Float(...) = aten::relu_(%input.3), scope: __module.relu # ${torch}/nn file
+        # the scope part might not exist
+        if node.scopeName():
+            scope_name_str = "jit._trace.TopLevelTracedModule: " + 
node.scopeName()
+        else:
+            scope_name_str = "warning: no trace info " + str(empty_counter)
+            empty_counter += 1
+        span_str = "C.graph: {}, {}".format(node.kind(), scope_name_str)
+        return span_str, empty_counter
+
 
 def _pytorch_result_type(dtypes, non_tensor_inputs):
     """This promotes TVM dtypes like PyTorch would"""
diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py
index d35e0e1..c2aa5a1 100644
--- a/python/tvm/relay/frontend/tensorflow.py
+++ b/python/tvm/relay/frontend/tensorflow.py
@@ -37,6 +37,7 @@ from .common import get_relay_op
 from .common import infer_type as _infer_type
 from .common import infer_shape as _infer_shape
 from .common import infer_value as _infer_value
+from .common import set_span
 
 from .tensorflow_ops import _convert_map
 from .tensorflow_ops import _need_prelude_for_shape_inference
@@ -1028,24 +1029,10 @@ class GraphProto(object):
         else:
             raise NotImplementedError("Operator {} not implemented.".format(op_name))
 
-        sym = self._set_span(sym, node_name)
+        sym = set_span(sym, node_name)
 
         return sym
 
-    @staticmethod
-    def _set_span(sym, node_name):
-        span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
-        if isinstance(sym, _expr.Call) and sym.span is None:
-            sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
-        elif isinstance(sym, _expr.TupleWrapper):
-            tuple_value = sym.tuple_value
-            if isinstance(tuple_value, _expr.Call) and tuple_value.span is None:
-                tuple_value = _expr.Call(
-                    tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
-                )
-                sym = _expr.TupleWrapper(tuple_value, sym.size)
-        return sym
-
     def _licm_construct(self, loop_name, node_name):
         """Construct a node by considering whether it is
         loop invariant with the given while loop. If yes, we
diff --git a/python/tvm/relay/frontend/tensorflow2.py b/python/tvm/relay/frontend/tensorflow2.py
index 465f530..2c8b7d4 100644
--- a/python/tvm/relay/frontend/tensorflow2.py
+++ b/python/tvm/relay/frontend/tensorflow2.py
@@ -36,6 +36,7 @@ from .. import analysis
 from .. import function as _function
 from ..loops import while_loop as _while_loop
 from .common import infer_type as _infer_type
+from .common import set_span
 
 from .tensorflow_ops import _convert_map as _convert_map_common
 from .tensorflow_ops import _get_more_static_shape_rank
@@ -58,22 +59,6 @@ def _infer_type_with_prelude(val, prelude):
     return body.checked_type
 
 
-def set_span(sym, node_name):
-    """set span of symbol"""
-
-    span = tvm.relay.Span(tvm.relay.SourceName(node_name), 0, 0, 0, 0)
-    if isinstance(sym, _expr.Call):
-        sym = _expr.Call(sym.op, sym.args, sym.attrs, sym.type_args, span)
-    elif isinstance(sym, _expr.TupleWrapper):
-        tuple_value = sym.tuple_value
-        if isinstance(tuple_value, _expr.Call):
-            tuple_value = _expr.Call(
-                tuple_value.op, tuple_value.args, tuple_value.attrs, tuple_value.type_args, span
-            )
-            sym = _expr.TupleWrapper(tuple_value, sym.size)
-    return sym
-
-
 def is_tensor_list_constuctor(tf_node):
     """Check whether is tensor list constructor node."""
     return tf_node.op == "TensorListReserve"
diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index f0f20e1..b0b2bd3 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -32,6 +32,7 @@ from .. import op as _op
 from .. import qnn as _qnn
 from .common import ExprTable
 from .common import infer_shape as _infer_shape
+from .common import set_span
 from .common import to_int_list
 from .tflite_flexbuffer import FlexBufferDecoder
 
@@ -239,12 +240,17 @@ class OperatorConverter(object):
 
             if len(output_tensors) == 1:
                 tensor_idx = output_tensors[0].tensor_idx
-                self.exp_tab.set_expr(get_tensor_name(self.subgraph, tensor_idx), ret)
+                curr_output = get_tensor_name(self.subgraph, tensor_idx)
+                ret = set_span(ret, "location: {}, output_name: 
{}".format(op_idx, curr_output))
+                self.exp_tab.set_expr(curr_output, ret)
             else:
-                for idx, output_tensor in enumerate(output_tensors):
-                    self.exp_tab.set_expr(
-                        get_tensor_name(self.subgraph, output_tensor.tensor_idx), ret[idx]
-                    )
+                out_names = []
+                for output_tensor in output_tensors:
+                    out_names.append(get_tensor_name(self.subgraph, output_tensor.tensor_idx))
+                curr_output = ", ".join(out_names)
+                ret = set_span(ret, "location: {}, output_name: 
{}".format(op_idx, curr_output))
+                for idx, out_name in enumerate(out_names):
+                    self.exp_tab.set_expr(out_name, ret[idx])
 
     def get_op_code_str(self, op):
         """Get TFLite ops string representation"""
diff --git a/src/printer/relay_text_printer.cc b/src/printer/relay_text_printer.cc
index fdc6c37..7654ef1 100644
--- a/src/printer/relay_text_printer.cc
+++ b/src/printer/relay_text_printer.cc
@@ -389,12 +389,21 @@ Doc RelayTextPrinter::VisitExpr_(const TupleNode* op) {
   if (op->fields.size() == 1) {
     doc << ",";
   }
-  return doc << ")";
+  doc << ")";
+  if (op->span.defined()) {
+    doc << " /* " << PrintSpan(op->span) << " */";
+  }
+  return doc;
 }
 
 Doc RelayTextPrinter::VisitExpr_(const TupleGetItemNode* op) {
   Doc doc;
-  return doc << Print(op->tuple) << "." << op->index;
+  doc << Print(op->tuple) << "." << op->index;
+
+  if (op->span.defined()) {
+    doc << " /* " << PrintSpan(op->span) << " */";
+  }
+  return doc;
 }
 
 Doc RelayTextPrinter::VisitExpr_(const IfNode* op) {
@@ -968,11 +977,13 @@ Doc RelayTextPrinter::PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>&
   return doc;
 }
 
-Doc RelayTextPrinter::PrintSpan(const Span& span) {
+Doc RelayTextPrinter::PrintSpan(const Span& span, bool include_spans) {
   Doc doc;
-  const auto* span_node = span.as<SpanNode>();
-  ICHECK(span_node);
-  doc << span_node->source_name->name;
+  if (include_spans) {
+    const auto* span_node = span.as<SpanNode>();
+    ICHECK(span_node);
+    doc << span_node->source_name->name;
+  }
   return doc;
 }
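
With these printer changes, tuples and tuple indexing render their spans as
trailing comments, as calls already did. A hedged sketch of the expected text
output (names illustrative):

    import tvm
    from tvm import relay
    from tvm.relay.frontend.common import set_span

    x = relay.var("x", shape=(1, 8))
    expr = set_span(relay.TupleGetItem(relay.Tuple([relay.nn.relu(x), x]), 0), "my_node")
    # Printing should now show span comments on every line, e.g.:
    #   %0 = nn.relu(%x) /* my_node_PART_0 */;
    #   %1 = (%0, %x) /* my_node_PART_1 */;
    #   %2 = %1.0 /* my_node */
    print(relay.Function([x], expr))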
 
diff --git a/src/printer/text_printer.h b/src/printer/text_printer.h
index a4d0ff3..ca46700 100644
--- a/src/printer/text_printer.h
+++ b/src/printer/text_printer.h
@@ -113,7 +113,7 @@ class RelayTextPrinter : public ExprFunctor<Doc(const Expr&)>,
    */
   Doc PrintMapAsAttributeValue(const Map<ObjectRef, ObjectRef>& map);
 
-  Doc PrintSpan(const Span& span);
+  Doc PrintSpan(const Span& span, bool include_spans = true);
 
   Doc Print(const ObjectRef& node, bool meta = false, bool try_inline = false);
 
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index b680a49..f8cb4f0 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -362,8 +362,8 @@ TupleGetItem WithFields(TupleGetItem tuple_get_item, Optional<Expr> opt_tuple,
 
 TVM_REGISTER_NODE_TYPE(TupleGetItemNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index) {
-  return TupleGetItem(tuple, index);
+TVM_REGISTER_GLOBAL("relay.ir.TupleGetItem").set_body_typed([](Expr tuple, int index, Span span) {
+  return TupleGetItem(tuple, index, span);
 });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index 86970bf..a64fa0b 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -247,6 +247,53 @@ def verify_model(
     torch.cuda.empty_cache()
 
 
+def verify_span(model_name, input_data=[], custom_convert_map={}):
+    if isinstance(model_name, str):
+        baseline_model, baseline_input = load_model(model_name)
+    elif isinstance(input_data, list):
+        baseline_model = model_name
+        baseline_input = input_data
+    elif isinstance(input_data, torch.Tensor) or len(input_data.shape) == 0:
+        baseline_model = model_name
+        baseline_input = [input_data]
+    else:
+        assert False, "Unexpected input format"
+
+    trace = torch.jit.trace(baseline_model, [input.clone() for input in baseline_input])
+    if isinstance(baseline_model, torch.nn.Module):
+        trace = trace.float().eval()
+
+        if torch.cuda.is_available():
+            trace = trace.cuda()
+        else:
+            trace = trace.cpu()
+
+    input_names = ["input{}".format(idx) for idx, inp in 
enumerate(baseline_input)]
+    input_shapes = list(zip(input_names, [inp.shape for inp in 
baseline_input]))
+    mod, params = relay.frontend.from_pytorch(trace, input_shapes, 
custom_convert_map)
+
+    # collect all fail cases to make further improvement easier
+    fail_cases = []
+    mod_main_start = False
+    for line in str(mod).split("\n"):
+        if "@main" in line:
+            mod_main_start = True
+            continue
+
+        if mod_main_start:
+            if line == "}":
+                break
+            elif not ("/*" in line and "*/" in line):
+                fail_cases.append(line)
+
+    print(fail_cases)
+    assert len(fail_cases) == 0
+
+
+def test_span():
+    verify_span("resnet18")
+
+
 # Single operator tests
 @tvm.testing.uses_gpu
 def test_forward_pixel_shuffle():
diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py
index 338d219..be32ca3 100644
--- a/tests/python/frontend/tensorflow/test_forward.py
+++ b/tests/python/frontend/tensorflow/test_forward.py
@@ -298,6 +298,60 @@ def is_gpu_available():
         return False
 
 
+def verify_span(mod):
+    # collect all fail cases to make further improvement easier
+    fail_cases = []
+    mod_main_start = False
+    for line in str(mod).split("\n"):
+        if "@main" in line:
+            mod_main_start = True
+            continue
+
+        if mod_main_start:
+            if line == "}":
+                break
+            elif not ("/*" in line and "*/" in line):
+                fail_cases.append(line)
+
+    print(fail_cases)
+    assert len(fail_cases) == 0
+
+
+def simple_model():
+    input_node = tf.placeholder(shape=[None, None, 3, 1], dtype=np.float32, name="input")
+
+    shape = tf.shape(input_node)
+    stack = tf.stack([shape[0], 3, 3], axis=0)
+    output_node = tf.reshape(input_node, stack, name="output")
+    return output_node
+
+
+#######################################################################
+# Span fill up
+# -------
+def test_span_complement_simple_model():
+    with tf.Graph().as_default() as graph:
+        model_graph = simple_model()
+        graph_def = graph.as_graph_def()
+
+        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+        mod, params = relay.frontend.from_tensorflow(graph_def, shape={"input:0": (1, 3, 3, 1)})
+        verify_span(mod)
+
+
+def test_span_complement_big_model():
+    with tf.Graph().as_default() as graph:
+        graph_def = tf_testing.get_workload("ResnetV2/resnet-20180601_resnet_v2_imagenet-shapes.pb")
+        # Call the utility to import the graph definition into default graph.
+        graph_def = tf_testing.ProcessGraphDefParam(graph_def)
+
+        mod, params = relay.frontend.from_tensorflow(
+            graph_def, shape={"input_tensor:0": (128, 224, 224, 3)}
+        )
+        verify_span(mod)
+
+
 #######################################################################
 # Pooling
 # -------
diff --git a/tests/python/frontend/tensorflow2/test_sequential_models.py b/tests/python/frontend/tensorflow2/test_sequential_models.py
index 1b5a634..b76b4a7 100644
--- a/tests/python/frontend/tensorflow2/test_sequential_models.py
+++ b/tests/python/frontend/tensorflow2/test_sequential_models.py
@@ -26,6 +26,25 @@ from tensorflow.python.framework.convert_to_constants import convert_variables_t
 
 from common import compare_tf_tvm
 from common import run_tf_code
+from tvm.relay.frontend.tensorflow2 import from_tensorflow
+
+
+def verify_span(mod):
+    fail_cases = []
+    mod_main_start = False
+    for line in str(mod).split("\n"):
+        if "@main" in line:
+            mod_main_start = True
+            continue
+
+        if mod_main_start:
+            if line == "}":
+                break
+            elif not ("/*" in line and "*/" in line):
+                fail_cases.append(line)
+
+    print(fail_cases)
+    assert len(fail_cases) == 0
 
 
 def run_sequential_model(model_fn, input_shape):
@@ -48,7 +67,10 @@ def run_sequential_model(model_fn, input_shape):
         gdef = f.graph.as_graph_def(add_shapes=True)
         return gdef, _input, _output
 
-    compare_tf_tvm(*model_graph(model_fn, input_shape), runtime="vm")
+    gdef, _input, _output = model_graph(model_fn, input_shape)
+    mod, _ = from_tensorflow(gdef)
+    compare_tf_tvm(gdef, _input, _output, runtime="vm")
+    verify_span(mod)
 
 
 def test_dense_model():
diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 545315a..d234cd1 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -259,6 +259,59 @@ def run_tflite_graph(tflite_model_buf, input_data):
     return tflite_output
 
 
+def run_span_verification(
+    tflite_model_buf,
+    input_data,
+    input_node,
+    num_output=1,
+    target="llvm",
+    out_names=None,
+    mode="graph_executor",
+):
+    """Generic function to compile on relay and execute on tvm"""
+    # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1
+    try:
+        import tflite.Model
+
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+    except AttributeError:
+        import tflite
+
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+    except ImportError:
+        raise ImportError("The tflite package must be installed")
+
+    input_data = convert_to_list(input_data)
+    input_node = convert_to_list(input_node)
+
+    shape_dict = {}
+    dtype_dict = {}
+    for i, e in enumerate(input_node):
+        shape_dict[e] = input_data[i].shape
+        dtype_dict[e] = input_data[i].dtype.name
+
+    mod, _ = relay.frontend.from_tflite(tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict)
+    verify_span(mod)
+
+
+def verify_span(mod):
+    fail_cases = []
+    mod_main_start = False
+    for line in str(mod).split("\n"):
+        if "@main" in line:
+            mod_main_start = True
+            continue
+
+        if mod_main_start:
+            if line == "}":
+                break
+            elif not ("/*" in line and "*/" in line):
+                fail_cases.append(line)
+
+    print(fail_cases)
+    assert len(fail_cases) == 0
+
+
 def compare_tflite_with_tvm(
     in_data,
     in_name,
@@ -4507,6 +4560,7 @@ def test_forward_tflite2_qnn_resnet50():
         tflite_output = run_tflite_graph(tflite_model_buf, data)
         tflite_predictions = np.squeeze(tflite_output)
         tflite_sorted_labels = tflite_predictions.argsort()[-3:][::-1]
+        run_span_verification(tflite_model_buf, np.array(data), "input_1")
         tvm_output = run_tvm_graph(tflite_model_buf, np.array(data), "input_1")
         tvm_predictions = np.squeeze(tvm_output)
         tvm_sorted_labels = tvm_predictions.argsort()[-3:][::-1]
