This is an automated email from the ASF dual-hosted git repository. areusch pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 06a0d63 Sanitize names of input tensors in interface header (#8720) 06a0d63 is described below commit 06a0d63c43e251cfc9fa9e81bfad5aa45219652e Author: Grant Watson <grant.wat...@arm.com> AuthorDate: Fri Sep 3 17:18:10 2021 +0100 Sanitize names of input tensors in interface header (#8720) * Sanitize names of input tensors in interface header Change-Id: I7f02a993887bf84316262cd2586a734a9079c338 * Update tensor name sanitizer tests to parameterize them. Change-Id: I157d8d8d607de2904285e403893f146e97b510d5 * Only test unpacked, C interface API, AOT case Change-Id: I9082ae32079a1a3924c06c7f26c757aafa46dec2 --- python/tvm/micro/interface_api.py | 13 +++++++- src/target/source/source_module.cc | 6 +++- tests/python/relay/aot/aot_test_utils.py | 17 +++++++--- tests/python/relay/aot/test_crt_aot.py | 56 ++++++++++++++++++++++++++++++++ 4 files changed, 86 insertions(+), 6 deletions(-) diff --git a/python/tvm/micro/interface_api.py b/python/tvm/micro/interface_api.py index 8086b1e..d9961e9 100644 --- a/python/tvm/micro/interface_api.py +++ b/python/tvm/micro/interface_api.py @@ -17,7 +17,13 @@ """Defines functions for generating a C interface header""" +# TODO: Currently the Interface API header is generated in Python but the source it references +# is generated in C++. These should be consolidated to generate both header and source in C++ +# and avoid re-implementing logic, such as name sanitising, in the two different languages. +# See https://github.com/apache/tvm/issues/8792 . + import os +import re from tvm.relay.backend.utils import mangle_module_name @@ -58,8 +64,13 @@ def generate_c_interface_header(module_name, inputs, outputs, output_path): _emit_brief(header_file, module_name, "Input tensor pointers") header_file.write(f"struct {mangled_name}_inputs {{\n") + sanitized_names = [] for input_name in inputs: - header_file.write(f" void* {input_name};\n") + sanitized_input_name = re.sub(r"\W", "_", input_name) + if sanitized_input_name in sanitized_names: + raise ValueError(f"Sanitized input tensor name clash: {sanitized_input_name}") + sanitized_names.append(sanitized_input_name) + header_file.write(f" void* {sanitized_input_name};\n") header_file.write("};\n\n") _emit_brief(header_file, module_name, "Output tensor pointers") diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 7728773..9b93b07 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -234,6 +234,8 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << "}\n"; } + static int isNotAlnum(char c) { return !std::isalnum(c); } + void GenerateCInterfaceEntrypoint(const std::string& entrypoint_name, const std::string& run_func, const std::string& mod_name) { code_ << "#include <" << mod_name << ".h>\n"; @@ -252,7 +254,9 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { << ") {"; code_ << "return " << run_func << "("; for (const auto& input : metadata_->inputs) { - code_ << "inputs->" << input << ","; + std::string sanitised_input = input; + std::replace_if(sanitised_input.begin(), sanitised_input.end(), isNotAlnum, '_'); + code_ << "inputs->" << sanitised_input << ","; } if (metadata_->num_outputs == 1) { code_ << "outputs->output"; diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index e5ac85b..baa2397 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -22,6 +22,7 @@ import logging import os import pathlib import platform +import re import shutil import subprocess import tarfile @@ -250,7 +251,10 @@ int main(){\n def emit_main_data(main_file, input_map, output_list, mod_name): for key in input_map: - main_file.write(f'#include "{mangle_name(mod_name,"input_data")}_{key}.h"\n') + sanitized_tensor_name = re.sub(r"\W", "_", key) + main_file.write( + f'#include "{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}.h"\n' + ) for i in range(0, len(output_list)): main_file.write(f'#include "{mangle_name(mod_name,"expected_output_data")}{i}.h"\n') @@ -262,7 +266,10 @@ def emit_main_data_structs(main_file, input_map, output_list, mod_name): f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{" ) for key in input_map: - main_file.write(f"\t.{key} = {mangle_name(mod_name, 'input_data')}_{key},\n") + sanitized_tensor_name = re.sub(r"\W", "_", key) + main_file.write( + f"\t.{sanitized_tensor_name} = {mangle_name(mod_name, 'input_data')}_{sanitized_tensor_name},\n" + ) main_file.write("};\n") main_file.write( @@ -283,7 +290,8 @@ def emit_main_data_setup(main_file, input_map, output_list, mod_name): main_file.write(f'void* {mangle_name(mod_name,"inputs")}[{num_inputs}] = {{ ') for key in input_map: - main_file.write(f'{mangle_name(mod_name,"input_data")}_{key}, ') + sanitized_tensor_name = re.sub(r"\W", "_", key) + main_file.write(f'{mangle_name(mod_name,"input_data")}_{sanitized_tensor_name}, ') main_file.write("};\n") main_file.write(f'void* {mangle_name(mod_name,"outputs")}[{num_outputs}] = {{ ') @@ -521,8 +529,9 @@ def compile_and_run( workspace_bytes += extract_main_workspace_size_bytes(base_path) for key in model.inputs: + sanitized_tensor_name = re.sub(r"\W", "_", key) create_header_file( - f'{mangle_name(model.name, "input_data")}_{key}', + f'{mangle_name(model.name, "input_data")}_{sanitized_tensor_name}', model.inputs[key], include_path, ) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 36cffef..64000a9 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -503,5 +503,61 @@ def test_transpose(interface_api, use_unpacked_api, test_runner): ) +def test_name_sanitiser(): + """Test that input tensors with special characters in the name don't break compilation""" + + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_DEFAULT_RUNNER + + func = relay.var("input-x::2", "float32") + ident = relay.Function([func], func) + one = np.array(1.0, "float32") + inputs = {"input-x::2": one} + output_list = generate_ref_data(ident, inputs) + + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + enable_op_fusion=False, + ) + + +def test_name_sanitiser_name_clash(): + """Test that 2 input tensors with names that clash once sanitized, generates an error""" + + interface_api = "c" + use_unpacked_api = True + test_runner = AOT_DEFAULT_RUNNER + + dtype = "float32" + x = relay.var("input::-1", shape=(10, 5), dtype=dtype) + # Next 2 input tensor names will clash once sanitized. + y = relay.var("input::-2", shape=(10, 5), dtype=dtype) + t = relay.var("input:--2", shape=(), dtype=dtype) + a = relay.add(x, y) + b = relay.transpose(a) + z = relay.add(b, t) + # Check result. + func = relay.Function([x, y, t], z) + x_data = np.random.rand(10, 5).astype(dtype) + y_data = np.random.rand(10, 5).astype(dtype) + t_data = np.random.uniform(size=()).astype(dtype) + + inputs = {"input::-1": x_data, "input::-2": y_data, "input:--2": t_data} + output_list = generate_ref_data(func, inputs) + + with pytest.raises(ValueError, match="Sanitized input tensor name clash"): + compile_and_run( + AOTTestModel(module=IRModule.from_expr(func), inputs=inputs, outputs=output_list), + test_runner, + interface_api, + use_unpacked_api, + enable_op_fusion=False, + ) + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))