This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7afac14ebd [BugFix][MSC] split name_string with index by colon from the right (#17000)
7afac14ebd is described below

commit 7afac14ebd0f22a5a53c51d362a5bc853fb1c868
Author: Peng Sun <peng....@arm.com>
AuthorDate: Wed May 29 09:41:33 2024 +0100

    [BugFix][MSC] split name_string with index by colon from the right (#17000)
    
    Fixes a naming mismatch in MSCGraph where tensor_name could be formatted as
    'string:index:index', and the corresponding node.name is 'string:index'.
    Splitting tensor_name from the right aligns it correctly.
    
    For example, the TFLite default input name 'serving_default_input:0'
    becomes 'serving_default_input:0:0' in MSCGraph, while node.name remains
    'serving_default_input:0'.
---
 src/contrib/msc/core/utils.h                       |  2 +-
 .../contrib/test_msc/test_translate_relay.py       | 36 ++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/src/contrib/msc/core/utils.h b/src/contrib/msc/core/utils.h
index 5762c96352..6c39a8d0a1 100644
--- a/src/contrib/msc/core/utils.h
+++ b/src/contrib/msc/core/utils.h
@@ -142,7 +142,7 @@ class StringUtils {
    */
   TVM_DLL static const std::tuple<String, String> SplitOnce(const String& src_string,
                                                             const String& sep,
-                                                            bool from_left = true);
+                                                            bool from_left = false);
 
   /*!
    * \brief Get the tokens between left and right.
diff --git a/tests/python/contrib/test_msc/test_translate_relay.py b/tests/python/contrib/test_msc/test_translate_relay.py
index 39a45035a5..6c47b8b395 100644
--- a/tests/python/contrib/test_msc/test_translate_relay.py
+++ b/tests/python/contrib/test_msc/test_translate_relay.py
@@ -27,8 +27,11 @@ from torch.nn import Module
 import tvm.testing
 from tvm.relax.frontend.torch import from_fx
 from tvm.relay.frontend import from_pytorch
+from tvm import relay
+from tvm.ir.module import IRModule
 from tvm.contrib.msc.core.frontend import translate
 from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen
+from tvm.contrib.msc.core import utils as msc_utils
 
 
 def _valid_target(target):
@@ -1057,5 +1060,38 @@ def test_max():
     verify_model(Max(), [([256, 256], "float32"), ([256, 256], "float32")])
 
 
+def test_name_string_with_colon():
+    """test name string with colons,
+    e.g., TFLite default input name 'serving_default_input:0'
+    """
+
+    dtype = "float32"
+    x_var = relay.var("input_0:0", shape=(3, 5), dtype=dtype)
+    y_var = relay.var("input_1:0", shape=(3, 5), dtype=dtype)
+    z_add = relay.add(x_var, y_var)
+    func = relay.Function([x_var, y_var], z_add)
+    mod = IRModule()
+    mod["main"] = func
+
+    try:
+        graph, _ = translate.from_relay(mod)
+    except Exception as err:
+        raise RuntimeError(f"Translation from relay to graph failed: {err}")
+    inspect = graph.inspect()
+
+    expected = {
+        "inputs": [
+            {"name": "input_0:0", "shape": [3, 5], "dtype": dtype, "layout": ""},
+            {"name": "input_1:0", "shape": [3, 5], "dtype": dtype, "layout": ""},
+        ],
+        "outputs": [{"name": "add", "shape": [3, 5], "dtype": dtype, "layout": ""}],
+        "nodes": {"total": 3, "input": 2, "add": 1},
+    }
+
+    assert msc_utils.dict_equal(inspect, expected), "Inspect {} mismatch with expected {}".format(
+        inspect, expected
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to