lhutton1 commented on code in PR #13046:
URL: https://github.com/apache/tvm/pull/13046#discussion_r994516368


##########
tests/python/relay/aot/test_cpp_aot.py:
##########
@@ -203,5 +203,44 @@ def test_pass_wrong_device_arg():
     # TODO write asserts for # and type of device.
 
 
+@pytest.mark.parametrize("target_kind", ["c", "llvm"])
+@pytest.mark.parametrize("input_name", ["input:0", "input@0", "input_0"])
+def test_aot_input_name_with_special_character(target_kind: str, input_name: 
str):
+    """Test name transforms in AOT for input names with special characters."""
+    dtype = "float32"
+    input_1 = relay.var(input_name, shape=(10, 5), dtype=dtype)
+    weight = relay.var("weight", shape=(1, 5), dtype=dtype)
+    output = relay.add(input_1, weight)
+    func = relay.Function([input_1, weight], output)
+
+    input_data = np.random.rand(10, 5).astype(dtype)
+    weight_data = np.random.rand(1, 5).astype(dtype)
+    expected_output = input_data + weight_data
+    params = {"weight": weight_data}
+
+    with tvm.transform.PassContext(opt_level=3, 
config={"tir.disable_vectorize": True}):
+        mod = tvm.relay.build(
+            tvm.IRModule.from_expr(func),
+            target=target_kind,
+            params=params,
+            executor=tvm.relay.backend.Executor("aot", {"interface-api": 
"packed"}),
+        )
+    temp_dir = tvm.contrib.utils.TempDirectory()
+    test_so_path = temp_dir / "test.so"
+    mod.export_library(test_so_path, cc="c++", options=["-std=gnu++17", "-g3", 
"-O0"])
+    # test both original name and transformed name
+    for name in ["input_0", input_name]:
+        loaded_mod = tvm.runtime.load_module(test_so_path)
+        runner = 
tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0)))
+        inputs = {name: input_data}
+        runner.set_input(**inputs)
+
+        input_ind = runner.get_input_index(name)
+        assert (runner.get_input(input_ind).asnumpy() == input_data).all

Review Comment:
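   Are the parentheses missing from the `.all()` call? Without them, `.all` is just a bound method, which is always truthy, so this assertion can never fail.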



##########
tests/python/relay/aot/test_cpp_aot.py:
##########
@@ -203,5 +203,44 @@ def test_pass_wrong_device_arg():
     # TODO write asserts for # and type of device.
 
 
+@pytest.mark.parametrize("target_kind", ["c", "llvm"])
+@pytest.mark.parametrize("input_name", ["input:0", "input@0", "input_0"])
+def test_aot_input_name_with_special_character(target_kind: str, input_name: 
str):
+    """Test name transforms in AOT for input names with special characters."""
+    dtype = "float32"
+    input_1 = relay.var(input_name, shape=(10, 5), dtype=dtype)
+    weight = relay.var("weight", shape=(1, 5), dtype=dtype)
+    output = relay.add(input_1, weight)
+    func = relay.Function([input_1, weight], output)
+
+    input_data = np.random.rand(10, 5).astype(dtype)
+    weight_data = np.random.rand(1, 5).astype(dtype)
+    expected_output = input_data + weight_data
+    params = {"weight": weight_data}
+
+    with tvm.transform.PassContext(opt_level=3, 
config={"tir.disable_vectorize": True}):
+        mod = tvm.relay.build(
+            tvm.IRModule.from_expr(func),
+            target=target_kind,
+            params=params,
+            executor=tvm.relay.backend.Executor("aot", {"interface-api": 
"packed"}),
+        )
+    temp_dir = tvm.contrib.utils.TempDirectory()
+    test_so_path = temp_dir / "test.so"
+    mod.export_library(test_so_path, cc="c++", options=["-std=gnu++17", "-g3", 
"-O0"])
+    # test both original name and transformed name
+    for name in ["input_0", input_name]:

Review Comment:
   This `for` loop seems redundant.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to