This is an automated email from the ASF dual-hosted git repository. lunderberg pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 9b5a7a457f [IR] Provide well-formed intermediate in ApplyPassToFunction (#16843) 9b5a7a457f is described below commit 9b5a7a457fc967bc38155abc1a71431603c76009 Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Fri Apr 5 13:21:52 2024 -0500 [IR] Provide well-formed intermediate in ApplyPassToFunction (#16843) Prior to this commit, `ApplyPassToFunction` removed functions from the `IRModule` to hide them from the inner `ir.transform.Pass`. The dangling `GlobalVar` references to those functions meant that the intermediate `IRModule` was ill-formed. This commit updates the `ApplyPassToFunction` utility to instead replace the functions with `ExternFunc` nodes. This still prevents the inner `ir.transform.Pass` from having visibility into functions that should not be mutated, but provides a well-formed `IRModule`. --- src/ir/apply_pass_to_function.cc | 136 +++++++++++++++++++++ src/ir/transform.cc | 32 +---- .../relax/test_transform_dead_code_elimination.py | 4 - 3 files changed, 137 insertions(+), 35 deletions(-) diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc new file mode 100644 index 0000000000..7f7bc7e90a --- /dev/null +++ b/src/ir/apply_pass_to_function.cc @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/ir/apply_pass_to_function.cc + * \brief Utility transformation that applies an inner pass to a subset of an IRModule + */ +#include <tvm/ir/transform.h> +#include <tvm/relax/expr.h> +#include <tvm/runtime/registry.h> +#include <tvm/tir/function.h> + +#include <unordered_set> + +#include "../runtime/regex.h" + +namespace tvm { +namespace transform { + +namespace { +BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, ObjectRef attr_value) { + if (auto tir = func.as<tir::PrimFunc>()) { + return WithAttr(tir.value(), attr_key, attr_value); + } else if (auto relax = func.as<relax::Function>()) { + return WithAttr(relax.value(), attr_key, attr_value); + } else { + return func; + } +} + +BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) { + if (auto tir = func.as<tir::PrimFunc>()) { + return WithoutAttr(tir.value(), attr_key); + } else if (auto relax = func.as<relax::Function>()) { + return WithoutAttr(relax.value(), attr_key); + } else { + return func; + } +} +} // namespace + +Pass ApplyPassToFunction(Pass pass, String func_name_regex, + bool error_if_no_function_matches_regex) { + auto pass_name = + static_cast<const std::stringstream&>(std::stringstream() << "ApplyPassTo" << func_name_regex) + .str(); + + auto pass_func = [pass, func_name_regex, error_if_no_function_matches_regex]( + IRModule mod, PassContext) -> IRModule { + bool at_least_one_function_matched_regex = false; + std::unordered_set<String> keep_original_version; + std::unordered_set<String> internal_functions; + IRModule 
subset; + + for (auto [gvar, func] : mod->functions) { + std::string name = gvar->name_hint; + if (tvm::runtime::regex_match(name, func_name_regex)) { + at_least_one_function_matched_regex = true; + if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) { + // Function may be mutated, but is an internal function. Mark + // it as externally-exposed, so that any call-tracing internal + // transforms do not remove this function, in case it its + // callers are not being mutated. + + internal_functions.insert(gvar->name_hint); + func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint); + } + } else { + // Function may not be mutated. Replace it with a + // `relax::ExternFunc` to prevent references to it from + // dangling. + keep_original_version.insert(gvar->name_hint); + func = relax::ExternFunc("dummy_" + name); + func->struct_info_ = gvar->struct_info_; + func->checked_type_ = gvar->checked_type_; + } + + subset->Add(gvar, func); + } + + if (error_if_no_function_matches_regex) { + CHECK(at_least_one_function_matched_regex) + << "No function matched regex '" << func_name_regex << "', out of functions " << [&]() { + Array<String> function_names; + for (const auto& [gvar, func] : mod->functions) { + function_names.push_back(gvar->name_hint); + } + return function_names; + }(); + } + + IRModule new_subset = pass(subset); + if (new_subset.same_as(subset)) { + return mod; + } + + auto write_ptr = mod.CopyOnWrite(); + for (auto [gvar, func] : new_subset->functions) { + if (!keep_original_version.count(gvar->name_hint)) { + if (auto it = write_ptr->global_var_map_.find(gvar->name_hint); + it != write_ptr->global_var_map_.end()) { + write_ptr->Remove((*it).second); + } + if (internal_functions.count(gvar->name_hint)) { + func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol); + } + write_ptr->Add(gvar, func); + } + } + + return mod; + }; + + return CreateModulePass(pass_func, 0, pass_name, {}); +} + 
+TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction); + +} // namespace transform +} // namespace tvm diff --git a/src/ir/transform.cc b/src/ir/transform.cc index 3eb64fec84..dc67822411 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -25,6 +25,7 @@ #include <tvm/ir/transform.h> #include <tvm/node/repr_printer.h> #include <tvm/node/structural_hash.h> +#include <tvm/relax/expr.h> #include <tvm/relax/tuning_api.h> #include <tvm/runtime/device_api.h> #include <tvm/runtime/registry.h> @@ -532,37 +533,6 @@ Pass CreateModulePass(const runtime::TypedPackedFunc<IRModule(IRModule, PassCont return ModulePass(pass_func, pass_info); } -Pass ApplyPassToFunction(Pass pass, String func_name_regex, - bool error_if_no_function_matches_regex) { - auto pass_name = - static_cast<const std::stringstream&>(std::stringstream() << "ApplyPassTo" << func_name_regex) - .str(); - - auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> IRModule { - IRModule subset; - - for (const auto& [gvar, func] : mod->functions) { - std::string name = gvar->name_hint; - if (tvm::runtime::regex_match(name, func_name_regex)) { - subset->Add(gvar, func); - } - } - - if (subset->functions.size()) { - IRModule new_subset = pass(subset); - if (!new_subset.same_as(subset)) { - mod.CopyOnWrite()->Update(new_subset); - } - } - - return mod; - }; - - return CreateModulePass(pass_func, 0, pass_name, {}); -} - -TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction); - TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("transform.PassInfo") diff --git a/tests/python/relax/test_transform_dead_code_elimination.py b/tests/python/relax/test_transform_dead_code_elimination.py index 2dae252cad..0cb0d46247 100644 --- a/tests/python/relax/test_transform_dead_code_elimination.py +++ b/tests/python/relax/test_transform_dead_code_elimination.py @@ -509,8 +509,6 @@ def test_extern_func(): verify(before, before) 
-@pytest.mark.skip_well_formed_check_before_transform -@pytest.mark.skip_well_formed_check_after_transform def test_compatibility_with_apply_pass_to_function(): """DeadCodeElimination can be used with ApplyPassToFunction @@ -590,8 +588,6 @@ def test_compatibility_with_apply_pass_to_function(): tvm.ir.assert_structural_equal(Expected, After) -@pytest.mark.skip_well_formed_check_before_transform -@pytest.mark.skip_well_formed_check_after_transform def test_well_formed_output_with_restricted_scope(): """DeadCodeElimination can be used with ApplyPassToFunction