This is an automated email from the ASF dual-hosted git repository. jcf94 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new a468f08 Add metadata information to the listing of PassContext configuration listing function (#8226) a468f08 is described below commit a468f08d77ae8bc0dfd492cf5adfcafd026090aa Author: Leandro Nunes <leandro.nu...@arm.com> AuthorDate: Thu Jun 10 04:14:05 2021 +0100 Add metadata information to the listing of PassContext configuration listing function (#8226) * Rename PassContext::ListConfigNames() to PassContext::ListConfigs() and its Python counterpart tvm.ir.transform.PassContext.list_config_names -> list_configs() * Adjust PassContext::ListConfigs() to include also metadata (currently only including the data type) * Adjust unit tests --- include/tvm/ir/transform.h | 6 +++--- python/tvm/ir/transform.py | 12 +++++++++--- src/ir/transform.cc | 16 +++++++++------- tests/cpp/relay_transform_sequential_test.cc | 9 +++++---- tests/python/relay/test_pass_instrument.py | 7 ++++--- 5 files changed, 30 insertions(+), 20 deletions(-) diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h index d5b50a7..cb556fc 100644 --- a/include/tvm/ir/transform.h +++ b/include/tvm/ir/transform.h @@ -184,10 +184,10 @@ class PassContext : public ObjectRef { TVM_DLL static PassContext Current(); /*! - * \brief Get all supported configuration names, registered within the PassContext. - * \return List of all configuration names. + * \brief Get all supported configuration names and metadata, registered within the PassContext. + * \return Map indexed by the config name, pointing to the metadata map as key-value */ - TVM_DLL static Array<String> ListConfigNames(); + TVM_DLL static Map<String, Map<String, String>> ListConfigs(); /*! * \brief Call instrument implementations' callbacks when entering PassContext. diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 7a0ea82..9296244 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -121,9 +121,15 @@ class PassContext(tvm.runtime.Object): return _ffi_transform_api.GetCurrentPassContext() @staticmethod - def list_config_names(): - """List all registered `PassContext` configuration names""" - return list(_ffi_transform_api.ListConfigNames()) + def list_configs(): + """List all registered `PassContext` configuration names and metadata. + + Returns + ------- + configs : Dict[str, Dict[str, str]] + + """ + return _ffi_transform_api.ListConfigs() @tvm._ffi.register_object("transform.Pass") diff --git a/src/ir/transform.cc b/src/ir/transform.cc index a8541b1..8120ca7 100644 --- a/src/ir/transform.cc +++ b/src/ir/transform.cc @@ -145,12 +145,14 @@ class PassConfigManager { } } - Array<String> ListConfigNames() { - Array<String> config_keys; + Map<String, Map<String, String>> ListConfigs() { + Map<String, Map<String, String>> configs; for (const auto& kv : key2vtype_) { - config_keys.push_back(kv.first); + Map<String, String> metadata; + metadata.Set("type", kv.second.type_key); + configs.Set(kv.first, metadata); } - return config_keys; + return configs; } static PassConfigManager* Global() { @@ -171,8 +173,8 @@ void PassContext::RegisterConfigOption(const char* key, uint32_t value_type_inde PassConfigManager::Global()->Register(key, value_type_index); } -Array<String> PassContext::ListConfigNames() { - return PassConfigManager::Global()->ListConfigNames(); +Map<String, Map<String, String>> PassContext::ListConfigs() { + return PassConfigManager::Global()->ListConfigs(); } PassContext PassContext::Create() { return PassContext(make_object<PassContextNode>()); } @@ -619,7 +621,7 @@ Pass PrintIR(String header, bool show_meta_data) { TVM_REGISTER_GLOBAL("transform.PrintIR").set_body_typed(PrintIR); -TVM_REGISTER_GLOBAL("transform.ListConfigNames").set_body_typed(PassContext::ListConfigNames); +TVM_REGISTER_GLOBAL("transform.ListConfigs").set_body_typed(PassContext::ListConfigs); } // namespace transform } // namespace tvm diff --git a/tests/cpp/relay_transform_sequential_test.cc b/tests/cpp/relay_transform_sequential_test.cc index 16e9438..6d38e10 100644 --- a/tests/cpp/relay_transform_sequential_test.cc +++ b/tests/cpp/relay_transform_sequential_test.cc @@ -121,11 +121,12 @@ TEST(Relay, Sequential) { ICHECK(tvm::StructuralEqual()(f, expected)); } -TEST(PassContextListConfigNames, Basic) { - Array<String> configs = relay::transform::PassContext::ListConfigNames(); +TEST(PassContextListConfigs, Basic) { + Map<String, Map<String, String>> configs = relay::transform::PassContext::ListConfigs(); ICHECK_EQ(configs.empty(), false); - ICHECK_EQ(std::count(std::begin(configs), std::end(configs), "relay.backend.use_auto_scheduler"), - 1); + + auto config = configs["relay.backend.use_auto_scheduler"]; + ICHECK_EQ(config["type"], "IntImm"); } int main(int argc, char** argv) { diff --git a/tests/python/relay/test_pass_instrument.py b/tests/python/relay/test_pass_instrument.py index c7405ae..610d4e4 100644 --- a/tests/python/relay/test_pass_instrument.py +++ b/tests/python/relay/test_pass_instrument.py @@ -183,10 +183,11 @@ def test_instrument_pass_counts(): def test_list_pass_configs(): - config_names = tvm.transform.PassContext.list_config_names() + configs = tvm.transform.PassContext.list_configs() - assert len(config_names) > 0 - assert "relay.backend.use_auto_scheduler" in config_names + assert len(configs) > 0 + assert "relay.backend.use_auto_scheduler" in configs.keys() + assert configs["relay.backend.use_auto_scheduler"]["type"] == "IntImm" def test_enter_pass_ctx_exception():