This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9b5a7a457f [IR] Provide well-formed intermediate in 
ApplyPassToFunction (#16843)
9b5a7a457f is described below

commit 9b5a7a457fc967bc38155abc1a71431603c76009
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Fri Apr 5 13:21:52 2024 -0500

    [IR] Provide well-formed intermediate in ApplyPassToFunction (#16843)
    
    Prior to this commit, `ApplyPassToFunction` removed functions from the
    `IRModule` to hide them from the inner `ir.transform.Pass`.  The
    dangling `GlobalVar` references to those functions meant that the
    intermediate `IRModule` was ill-formed.  This commit updates the
    `ApplyPassToFunction` utility to instead replace the functions with
    `ExternFunc` nodes.  This still prevents the inner `ir.transform.Pass`
    from having visibility into functions that should not be mutated, but
    provides a well-formed `IRModule`.
---
 src/ir/apply_pass_to_function.cc                   | 136 +++++++++++++++++++++
 src/ir/transform.cc                                |  32 +----
 .../relax/test_transform_dead_code_elimination.py  |   4 -
 3 files changed, 137 insertions(+), 35 deletions(-)

diff --git a/src/ir/apply_pass_to_function.cc b/src/ir/apply_pass_to_function.cc
new file mode 100644
index 0000000000..7f7bc7e90a
--- /dev/null
+++ b/src/ir/apply_pass_to_function.cc
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/ir/apply_pass_to_function.cc
+ * \brief Utility transformation that applies an inner pass to a subset of an 
IRModule
+ */
+#include <tvm/ir/transform.h>
+#include <tvm/relax/expr.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/function.h>
+
+#include <unordered_set>
+
+#include "../runtime/regex.h"
+
+namespace tvm {
+namespace transform {
+
+namespace {
+BaseFunc BaseFuncWithAttr(BaseFunc func, const std::string& attr_key, 
ObjectRef attr_value) {  // Forward to WithAttr(func, attr_key, attr_value) for the function kinds handled below.
+  if (auto tir = func.as<tir::PrimFunc>()) {  // func is a tir::PrimFunc
+    return WithAttr(tir.value(), attr_key, attr_value);
+  } else if (auto relax = func.as<relax::Function>()) {  // func is a relax::Function
+    return WithAttr(relax.value(), attr_key, attr_value);
+  } else {  // any other BaseFunc kind is returned as-is, without the attribute
+    return func;
+  }
+}
+
+BaseFunc BaseFuncWithoutAttr(BaseFunc func, const std::string& attr_key) {  // Forward to WithoutAttr(func, attr_key) for the function kinds handled below.
+  if (auto tir = func.as<tir::PrimFunc>()) {  // func is a tir::PrimFunc
+    return WithoutAttr(tir.value(), attr_key);
+  } else if (auto relax = func.as<relax::Function>()) {  // func is a relax::Function
+    return WithoutAttr(relax.value(), attr_key);
+  } else {  // any other BaseFunc kind is returned as-is
+    return func;
+  }
+}
+}  // namespace
+
+Pass ApplyPassToFunction(Pass pass, String func_name_regex,
+                         bool error_if_no_function_matches_regex) {  // Wrap `pass` in a module pass that only lets it mutate functions whose name matches func_name_regex.
+  auto pass_name =  // "ApplyPassTo" followed by the regex text
+      static_cast<const std::stringstream&>(std::stringstream() << 
"ApplyPassTo" << func_name_regex)
+          .str();
+
+  auto pass_func = [pass, func_name_regex, error_if_no_function_matches_regex](
+                       IRModule mod, PassContext) -> IRModule {
+    bool at_least_one_function_matched_regex = false;
+    std::unordered_set<String> keep_original_version;  // names that did NOT match the regex
+    std::unordered_set<String> internal_functions;  // matched names that had no kGlobalSymbol attribute
+    IRModule subset;  // module that is handed to the inner pass
+
+    for (auto [gvar, func] : mod->functions) {
+      std::string name = gvar->name_hint;
+      if (tvm::runtime::regex_match(name, func_name_regex)) {
+        at_least_one_function_matched_regex = true;
+        if (!func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined()) {
+          // Function may be mutated, but is an internal function.  Mark
+          // it as externally-exposed, so that any call-tracing internal
+          // transforms do not remove this function, in case its
+          // callers are not being mutated.
+
+          internal_functions.insert(gvar->name_hint);
+          func = BaseFuncWithAttr(func, tvm::attr::kGlobalSymbol, 
gvar->name_hint);
+        }
+      } else {
+        // Function must not be mutated.  Replace it with a
+        // `relax::ExternFunc` to prevent references to it from
+        // dangling.
+        keep_original_version.insert(gvar->name_hint);
+        func = relax::ExternFunc("dummy_" + name);
+        func->struct_info_ = gvar->struct_info_;  // placeholder takes its struct info from the GlobalVar
+        func->checked_type_ = gvar->checked_type_;  // ...and its checked type as well
+      }
+
+      subset->Add(gvar, func);  // every function of `mod` gets an entry, either itself or a placeholder
+    }
+
+    if (error_if_no_function_matches_regex) {
+      CHECK(at_least_one_function_matched_regex)
+          << "No function matched regex '" << func_name_regex << "', out of 
functions " << [&]() {
+               Array<String> function_names;
+               for (const auto& [gvar, func] : mod->functions) {
+                 function_names.push_back(gvar->name_hint);
+               }
+               return function_names;
+             }();
+    }
+
+    IRModule new_subset = pass(subset);
+    if (new_subset.same_as(subset)) {  // inner pass returned the same module object: nothing to copy back
+      return mod;
+    }
+
+    auto write_ptr = mod.CopyOnWrite();
+    for (auto [gvar, func] : new_subset->functions) {
+      if (!keep_original_version.count(gvar->name_hint)) {  // skip placeholders; their originals stay in `mod`
+        if (auto it = write_ptr->global_var_map_.find(gvar->name_hint);
+            it != write_ptr->global_var_map_.end()) {
+          write_ptr->Remove((*it).second);  // drop the entry currently registered under this name before re-adding
+        }
+        if (internal_functions.count(gvar->name_hint)) {
+          func = BaseFuncWithoutAttr(func, tvm::attr::kGlobalSymbol);  // undo the temporary kGlobalSymbol added before running the pass
+        }
+        write_ptr->Add(gvar, func);
+      }
+    }
+
+    return mod;
+  };
+
+  return CreateModulePass(pass_func, 0, pass_name, {});
+}
+
+TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction);  // registers ApplyPassToFunction under the global name "transform.ApplyPassToFunction"
+
+}  // namespace transform
+}  // namespace tvm
diff --git a/src/ir/transform.cc b/src/ir/transform.cc
index 3eb64fec84..dc67822411 100644
--- a/src/ir/transform.cc
+++ b/src/ir/transform.cc
@@ -25,6 +25,7 @@
 #include <tvm/ir/transform.h>
 #include <tvm/node/repr_printer.h>
 #include <tvm/node/structural_hash.h>
+#include <tvm/relax/expr.h>
 #include <tvm/relax/tuning_api.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/registry.h>
@@ -532,37 +533,6 @@ Pass CreateModulePass(const 
runtime::TypedPackedFunc<IRModule(IRModule, PassCont
   return ModulePass(pass_func, pass_info);
 }
 
-Pass ApplyPassToFunction(Pass pass, String func_name_regex,
-                         bool error_if_no_function_matches_regex) {
-  auto pass_name =
-      static_cast<const std::stringstream&>(std::stringstream() << 
"ApplyPassTo" << func_name_regex)
-          .str();
-
-  auto pass_func = [pass, func_name_regex](IRModule mod, PassContext) -> 
IRModule {
-    IRModule subset;
-
-    for (const auto& [gvar, func] : mod->functions) {
-      std::string name = gvar->name_hint;
-      if (tvm::runtime::regex_match(name, func_name_regex)) {
-        subset->Add(gvar, func);
-      }
-    }
-
-    if (subset->functions.size()) {
-      IRModule new_subset = pass(subset);
-      if (!new_subset.same_as(subset)) {
-        mod.CopyOnWrite()->Update(new_subset);
-      }
-    }
-
-    return mod;
-  };
-
-  return CreateModulePass(pass_func, 0, pass_name, {});
-}
-
-TVM_REGISTER_GLOBAL("transform.ApplyPassToFunction").set_body_typed(ApplyPassToFunction);
-
 TVM_REGISTER_NODE_TYPE(PassInfoNode);
 
 TVM_REGISTER_GLOBAL("transform.PassInfo")
diff --git a/tests/python/relax/test_transform_dead_code_elimination.py 
b/tests/python/relax/test_transform_dead_code_elimination.py
index 2dae252cad..0cb0d46247 100644
--- a/tests/python/relax/test_transform_dead_code_elimination.py
+++ b/tests/python/relax/test_transform_dead_code_elimination.py
@@ -509,8 +509,6 @@ def test_extern_func():
     verify(before, before)
 
 
-@pytest.mark.skip_well_formed_check_before_transform
-@pytest.mark.skip_well_formed_check_after_transform
 def test_compatibility_with_apply_pass_to_function():
     """DeadCodeElimination can be used with ApplyPassToFunction
 
@@ -590,8 +588,6 @@ def test_compatibility_with_apply_pass_to_function():
     tvm.ir.assert_structural_equal(Expected, After)
 
 
-@pytest.mark.skip_well_formed_check_before_transform
-@pytest.mark.skip_well_formed_check_after_transform
 def test_well_formed_output_with_restricted_scope():
     """DeadCodeElimination can be used with ApplyPassToFunction
 

Reply via email to