This is an automated email from the ASF dual-hosted git repository. lukhut pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 7afac14ebd [BugFix][MSC] split name_string with index by colon from the right (#17000) 7afac14ebd is described below commit 7afac14ebd0f22a5a53c51d362a5bc853fb1c868 Author: Peng Sun <peng....@arm.com> AuthorDate: Wed May 29 09:41:33 2024 +0100 [BugFix][MSC] split name_string with index by colon from the right (#17000) Fixes a naming mismatch in MSCGraph where tensor_name could be formatted as 'string:index:index', while the corresponding node.name is 'string:index'. Splitting tensor_name at its last colon (i.e. from the right) recovers the matching node.name. For example, the TFLite default input name 'serving_default_input:0' becomes 'serving_default_input:0:0' in MSCGraph, while node.name remains 'serving_default_input:0'. --- src/contrib/msc/core/utils.h | 2 +- .../contrib/test_msc/test_translate_relay.py | 36 ++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h index 5762c96352..6c39a8d0a1 100644 --- a/src/contrib/msc/core/utils.h +++ b/src/contrib/msc/core/utils.h @@ -142,7 +142,7 @@ class StringUtils { */ TVM_DLL static const std::tuple<String, String> SplitOnce(const String& src_string, const String& sep, - bool from_left = true); + bool from_left = false); /*! * \brief Get the tokens between left and right.
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py index 39a45035a5..6c47b8b395 100644 --- a/tests/python/contrib/test_msc/test_translate_relay.py +++ b/tests/python/contrib/test_msc/test_translate_relay.py @@ -27,8 +27,11 @@ from torch.nn import Module import tvm.testing from tvm.relax.frontend.torch import from_fx from tvm.relay.frontend import from_pytorch +from tvm import relay +from tvm.ir.module import IRModule from tvm.contrib.msc.core.frontend import translate from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen +from tvm.contrib.msc.core import utils as msc_utils def _valid_target(target): @@ -1057,5 +1060,38 @@ def test_max(): verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")]) +def test_name_string_with_colon(): + """test name string with colons, + e.g., TFLite default input name 'serving_default_input:0' + """ + + dtype = "float32" + x_var = relay.var("input_0:0", shape=(3, 5), dtype=dtype) + y_var = relay.var("input_1:0", shape=(3, 5), dtype=dtype) + z_add = relay.add(x_var, y_var) + func = relay.Function([x_var, y_var], z_add) + mod = IRModule() + mod["main"] = func + + try: + graph, _ = translate.from_relay(mod) + except Exception as err: + raise RuntimeError(f"Translation from relay to graph failed: {err}") + inspect = graph.inspect() + + expected = { + "inputs": [ + {"name": "input_0:0", "shape": [3, 5], "dtype": dtype, "layout": ""}, + {"name": "input_1:0", "shape": [3, 5], "dtype": dtype, "layout": ""}, + ], + "outputs": [{"name": "add", "shape": [3, 5], "dtype": dtype, "layout": ""}], + "nodes": {"total": 3, "input": 2, "add": 1}, + } + + assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with expected {}".format( + inspect, expected + ) + + if __name__ == "__main__": tvm.testing.main()