gemini-code-assist[bot] commented on code in PR #18645: URL: https://github.com/apache/tvm/pull/18645#discussion_r2670938404
########## tests/python/ir/test_ir_transform_resolve_dependency.py: ########## @@ -0,0 +1,103 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Tests for pass dependency resolution in Sequential passes. + +Note: ResolveDependency is a C++ function that needs to be exposed to Python +for direct testing. Currently, we test the behavior indirectly through +Sequential pass execution. 
+""" + +import tvm +import tvm.testing +from tvm.ir import transform +from tvm.ir.transform import PassContext +from tvm.ir.module import IRModule + + +def create_test_pass(name, required=None, opt_level=0): + """Helper function to create a test pass with specified dependencies.""" + + @transform.module_pass(opt_level=opt_level, name=name, required=required or [], traceable=False) + def pass_func(mod, ctx): + # Simple pass that just returns the module unchanged + return mod + + return pass_func + + +def test_sequential_with_dependencies(): + """Test that Sequential correctly handles pass dependencies during execution.""" + + # Create passes without dependencies to test basic execution + # The dependency resolution is tested at the C++ level through compilation + pass1 = create_test_pass("Pass1", required=[]) + pass2 = create_test_pass("Pass2", required=[]) + + # Create a sequential pass + seq = transform.Sequential([pass1, pass2]) + + # Create a simple IRModule for testing + mod = IRModule({}) + + # Execute the sequential pass + with PassContext(opt_level=3): + result = seq(mod) + + # Verify that the passes were executed + assert result is not None + assert isinstance(result, IRModule) + + +def test_sequential_opt_level_filtering(): + """Test that Sequential filters passes based on opt_level.""" + + pass1 = create_test_pass("Pass1", required=[], opt_level=1) + pass2 = create_test_pass("Pass2", required=[], opt_level=2) + pass3 = create_test_pass("Pass3", required=[], opt_level=3) + + seq = transform.Sequential([pass1, pass2, pass3]) + mod = IRModule({}) + + # With opt_level=2, pass3 (opt_level=3) should be skipped + with PassContext(opt_level=2): + result = seq(mod) + + # Execution should succeed even with some passes filtered + assert result is not None + + +def test_sequential_required_pass_execution(): + """Test that required passes are executed even if not in the list.""" + + # Create a pass that depends on PrintIR (a standard TVM pass) + # PrintIR requires a 
header string parameter + print_ir_pass = transform.PrintIR("TestHeader") + pass1 = create_test_pass("Pass1", required=[]) + + # Create sequential with both passes - pass1 should execute after print_ir + seq = transform.Sequential([pass1, print_ir_pass]) + mod = IRModule({}) + + # Execute - both passes should execute + with PassContext(opt_level=3): + result = seq(mod) + + assert result is not None + + +if __name__ == "__main__": + tvm.testing.main() Review Comment:  This new test file is a good start, but the current tests are quite basic and don't exercise the new dependency resolution logic in `ResolveDependency`: every pass built with `create_test_pass` uses `required=[]`, so there are no dependency edges to resolve. For instance, they don't cover transitive or circular dependencies. It would be great to add more comprehensive tests to ensure the topological sort and dependency handling work as expected. Here are some scenarios that would be valuable to test: * A simple dependency chain (e.g., A requires B, B requires C). * A pass required by multiple other passes. * A required pass that is not in the initial `Sequential` list and needs to be fetched from the registry. * A circular dependency, to verify that it's detected and handled gracefully with a warning. * A pass that is disabled via `PassContext`. To test the order of execution, you could create test passes that append their names to a list when they run, and then assert that the list is in the expected topologically sorted order.
########## src/ir/transform.cc: ########## @@ -463,6 +459,150 @@ Pass GetPass(const ffi::String& pass_name) { return (*f)().cast<Pass>(); } +void SequentialNode::ResolveDependency(const IRModule& mod) { + // Get the current pass context to check which passes are enabled + // Note: mod parameter is reserved for future use when dependency resolution + // might need to consider module-specific information + (void)mod; // Suppress unused parameter warning + PassContext pass_ctx = PassContext::Current(); + + // Step 1: Collect all enabled passes from the current list + std::unordered_map<std::string, Pass> name_to_pass; + std::vector<Pass> enabled_passes; + + for (const Pass& pass : passes) { + if (!pass.defined()) { + continue; + } + const PassInfo& pass_info = pass->Info(); + if (pass_ctx.PassEnabled(pass_info)) { + std::string pass_name = pass_info->name; + // Avoid duplicates + if (name_to_pass.find(pass_name) == name_to_pass.end()) { + name_to_pass[pass_name] = pass; + enabled_passes.push_back(pass); + } + } + } + + // Step 2: Collect all required passes that are not in the current list + // We need to do this in multiple passes to handle transitive dependencies + std::unordered_set<std::string> processed_required; + bool changed = true; + while (changed) { + changed = false; + for (size_t i = 0; i < enabled_passes.size(); ++i) { + const PassInfo& pass_info = enabled_passes[i]->Info(); + for (const auto& required_name : pass_info->required) { + std::string req_name = required_name; + std::string key = pass_info->name + "->" + req_name; + if (processed_required.find(key) != processed_required.end()) { + continue; + } + processed_required.insert(key); + + // Check if the required pass is already in our list + if (name_to_pass.find(req_name) == name_to_pass.end()) { + // Try to get it from the global registry + try { + Pass required_pass = GetPass(ffi::String(req_name)); + const PassInfo& req_pass_info = required_pass->Info(); + if 
(pass_ctx.PassEnabled(req_pass_info)) { + name_to_pass[req_name] = required_pass; + enabled_passes.push_back(required_pass); + changed = true; + } + } catch (...) { + // If we can't get the pass, we'll skip this dependency + // It will be resolved at runtime in operator() + VLOG(0) << "Warning: Cannot resolve required pass '" << req_name + << "' for pass '" << pass_info->name + << "'. It will be resolved at runtime if needed."; + } Review Comment:  Using `catch (...)` is generally discouraged as it can catch and hide unexpected exceptions, making debugging harder. Since `GetPass` is expected to throw a `dmlc::Error` if the pass is not found, it would be safer to catch a more specific exception type like `std::exception`. This would also allow logging more details about the error. ```suggestion } catch (const std::exception& e) { // If we can't get the pass, we'll skip this dependency // It will be resolved at runtime in operator() VLOG(0) << "Warning: Cannot resolve required pass '" << req_name << "' for pass '" << pass_info->name << "'. It will be resolved at runtime if needed. Details: " << e.what(); } ``` ########## src/ir/transform.cc: ########## @@ -463,6 +459,150 @@ Pass GetPass(const ffi::String& pass_name) { return (*f)().cast<Pass>(); } +void SequentialNode::ResolveDependency(const IRModule& mod) { Review Comment:  This is a great implementation of pass dependency resolution using topological sort. However, it seems this function is not called from `SequentialNode::operator()`, which still contains the old, simpler dependency resolution logic. As a result, the improvements made here won't be active when a `Sequential` pass is run. Was the intention to call this function from `SequentialNode::operator()`? If so, that operator would need to be updated. The current implementation of `ResolveDependency` also modifies `this->passes`, which might be an issue if called from a `const` method like `operator()`. 
It might be better for `ResolveDependency` to return the sorted list of passes instead of modifying the member variable in-place. ########## src/ir/transform.cc: ########## @@ -463,6 +459,150 @@ Pass GetPass(const ffi::String& pass_name) { return (*f)().cast<Pass>(); } +void SequentialNode::ResolveDependency(const IRModule& mod) { + // Get the current pass context to check which passes are enabled + // Note: mod parameter is reserved for future use when dependency resolution + // might need to consider module-specific information + (void)mod; // Suppress unused parameter warning + PassContext pass_ctx = PassContext::Current(); + + // Step 1: Collect all enabled passes from the current list + std::unordered_map<std::string, Pass> name_to_pass; + std::vector<Pass> enabled_passes; + + for (const Pass& pass : passes) { + if (!pass.defined()) { + continue; + } + const PassInfo& pass_info = pass->Info(); + if (pass_ctx.PassEnabled(pass_info)) { + std::string pass_name = pass_info->name; + // Avoid duplicates + if (name_to_pass.find(pass_name) == name_to_pass.end()) { + name_to_pass[pass_name] = pass; + enabled_passes.push_back(pass); + } + } + } + + // Step 2: Collect all required passes that are not in the current list + // We need to do this in multiple passes to handle transitive dependencies + std::unordered_set<std::string> processed_required; + bool changed = true; + while (changed) { + changed = false; + for (size_t i = 0; i < enabled_passes.size(); ++i) { + const PassInfo& pass_info = enabled_passes[i]->Info(); + for (const auto& required_name : pass_info->required) { + std::string req_name = required_name; + std::string key = pass_info->name + "->" + req_name; + if (processed_required.find(key) != processed_required.end()) { + continue; + } + processed_required.insert(key); + + // Check if the required pass is already in our list + if (name_to_pass.find(req_name) == name_to_pass.end()) { + // Try to get it from the global registry + try { + Pass required_pass = 
GetPass(ffi::String(req_name)); + const PassInfo& req_pass_info = required_pass->Info(); + if (pass_ctx.PassEnabled(req_pass_info)) { + name_to_pass[req_name] = required_pass; + enabled_passes.push_back(required_pass); + changed = true; + } + } catch (...) { + // If we can't get the pass, we'll skip this dependency + // It will be resolved at runtime in operator() + VLOG(0) << "Warning: Cannot resolve required pass '" << req_name + << "' for pass '" << pass_info->name + << "'. It will be resolved at runtime if needed."; + } + } + } + } + } + + // Step 3: Build dependency graph + // Map from pass name to its index in enabled_passes + std::unordered_map<std::string, size_t> name_to_index; + for (size_t i = 0; i < enabled_passes.size(); ++i) { + const PassInfo& pass_info = enabled_passes[i]->Info(); + name_to_index[pass_info->name] = i; + } + + // Build reverse adjacency list: dependents[i] contains indices of passes that depend on pass i + // This is used for topological sort + std::vector<std::vector<size_t>> dependents(enabled_passes.size()); + std::vector<size_t> in_degree(enabled_passes.size(), 0); + + for (size_t i = 0; i < enabled_passes.size(); ++i) { + const PassInfo& pass_info = enabled_passes[i]->Info(); + for (const auto& required_name : pass_info->required) { + std::string req_name = required_name; + auto it = name_to_index.find(req_name); + if (it != name_to_index.end()) { + // The required pass is in our enabled passes list + // pass i depends on pass req_idx, so req_idx should come before i + size_t req_idx = it->second; + dependents[req_idx].push_back(i); + in_degree[i]++; + } + // If the required pass is not in our list, it will be handled at runtime + } + } + + // Step 4: Topological sort using Kahn's algorithm + std::queue<size_t> queue; + for (size_t i = 0; i < enabled_passes.size(); ++i) { + if (in_degree[i] == 0) { + queue.push(i); + } + } + + std::vector<Pass> sorted_passes; + std::unordered_set<size_t> visited; Review Comment:  In Kahn's 
algorithm for topological sorting, a node is added to the queue only when its in-degree drops to zero (or initially, if it starts at zero). Because in-degrees only ever decrease, this happens at most once per node, whether or not the graph is a Directed Acyclic Graph (DAG): nodes on a cycle never reach zero and are never enqueued. Therefore, this `visited` set and its usages on lines 571-574 appear to be redundant, even when a cycle is present. You can simplify the code by removing them. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
