This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit d4b99761b28ce80e04f97a74b74c3834d13fd6d2 Author: tqchen <[email protected]> AuthorDate: Fri May 2 09:55:03 2025 -0400 [WIN] Windows support --- CMakeLists.txt | 7 ++- ffi/CMakeLists.txt | 2 +- ffi/cmake/Utils/Library.cmake | 9 ++-- ffi/include/tvm/ffi/base_details.h | 13 +++++ ffi/include/tvm/ffi/container/container_details.h | 9 ++++ ffi/include/tvm/ffi/container/ndarray.h | 2 +- ffi/include/tvm/ffi/container/tuple.h | 4 +- ffi/include/tvm/ffi/container/variant.h | 5 +- ffi/include/tvm/ffi/error.h | 10 +++- ffi/include/tvm/ffi/function_details.h | 3 ++ ffi/src/ffi/traceback_win.cc | 4 +- ffi/tests/cpp/CMakeLists.txt | 7 ++- ffi/tests/cpp/test_array.cc | 10 ++-- ffi/tests/cpp/test_map.cc | 2 +- ffi/tests/cpp/test_ndarray.cc | 2 +- ffi/tests/cpp/test_rvalue_ref.cc | 2 +- ffi/tests/cpp/test_variant.cc | 4 +- ffi/tests/cpp/testing_object.h | 11 ++-- include/tvm/ir/name_supply.h | 1 + include/tvm/meta_schedule/search_strategy.h | 64 +++++++++++------------ include/tvm/meta_schedule/space_generator.h | 64 +++++++++++------------ include/tvm/meta_schedule/task_scheduler.h | 64 +++++++++++------------ include/tvm/runtime/threading_backend.h | 2 +- include/tvm/tir/stmt.h | 15 ++++-- src/ir/function.cc | 3 +- src/meta_schedule/utils.h | 3 ++ src/relax/op/tensor/sorting.cc | 1 + src/relax/transform/lower_alloc_tensor.cc | 1 + src/relax/transform/realize_vdevice.cc | 1 + src/relax/transform/tuning_api/database.cc | 20 +++---- src/runtime/thread_pool.cc | 8 +-- src/runtime/threading_backend.cc | 2 +- src/script/printer/relax/tir.cc | 1 + src/script/printer/tir/buffer.cc | 1 + src/tir/analysis/stmt_finding.cc | 1 + tests/cpp/llvm_codegen_registry_test.cc | 2 +- tests/cpp/target_test.cc | 2 +- 37 files changed, 202 insertions(+), 160 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cf11f0281f..b45e5becf3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -159,7 +159,6 @@ if(MSVC) add_definitions(-D_SCL_SECURE_NO_WARNINGS) 
add_definitions(-D_ENABLE_EXTENDED_ALIGNED_STORAGE) add_definitions(-DNOMINMAX) - # regeneration does not work well with msbuild custom rules. set(CMAKE_SUPPRESS_REGENERATION ON) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc") @@ -761,9 +760,9 @@ install( # More target definitions if(MSVC) - target_compile_definitions(tvm_objs PRIVATE -DTVM_EXPORTS) - target_compile_definitions(tvm_libinfo_objs PRIVATE -DTVM_EXPORTS) - target_compile_definitions(tvm_runtime_objs PRIVATE -DTVM_EXPORTS) + target_compile_definitions(tvm_objs PRIVATE -DTVM_EXPORTS -DTVM_FFI_EXPORTS) + target_compile_definitions(tvm_libinfo_objs PRIVATE -DTVM_EXPORTS -DTVM_FFI_EXPORTS) + target_compile_definitions(tvm_runtime_objs PRIVATE -DTVM_EXPORTS -DTVM_FFI_EXPORTS) endif() set(TVM_IS_DEBUG_BUILD OFF) diff --git a/ffi/CMakeLists.txt b/ffi/CMakeLists.txt index 11e380e46f..209913d60f 100644 --- a/ffi/CMakeLists.txt +++ b/ffi/CMakeLists.txt @@ -111,7 +111,7 @@ target_link_libraries(tvm_ffi_static PUBLIC tvm_ffi_header) install(TARGETS tvm_ffi_static DESTINATION lib${LIB_SUFFIX}) install(TARGETS tvm_ffi_shared DESTINATION lib${LIB_SUFFIX}) -add_msvc_compact_defs(tvm_ffi_objs) +add_msvc_flags(tvm_ffi_objs) ########## Adding tests ########## diff --git a/ffi/cmake/Utils/Library.cmake b/ffi/cmake/Utils/Library.cmake index 6afea393d6..f391ee8fd4 100644 --- a/ffi/cmake/Utils/Library.cmake +++ b/ffi/cmake/Utils/Library.cmake @@ -28,7 +28,7 @@ function(add_dsymutil target_name) endif() endfunction() -function(add_msvc_compact_defs target_name) +function(add_msvc_flags target_name) # running if we are under msvc if(MSVC) target_compile_definitions(${target_name} PUBLIC -DWIN32_LEAN_AND_MEAN) @@ -36,6 +36,7 @@ function(add_msvc_compact_defs target_name) target_compile_definitions(${target_name} PUBLIC -D_SCL_SECURE_NO_WARNINGS) target_compile_definitions(${target_name} PUBLIC -D_ENABLE_EXTENDED_ALIGNED_STORAGE) target_compile_definitions(${target_name} PUBLIC -DNOMINMAX) + target_compile_options(${target_name} 
PRIVATE "/Z7") endif() endfunction() @@ -60,11 +61,7 @@ function(add_target_from_obj target_name obj_target_name) add_dependencies(${target_name} ${target_name}_static ${target_name}_shared) if (MSVC) target_compile_definitions(${obj_target_name} PRIVATE TVM_FFI_EXPORTS) - set_target_properties( - ${obj_target_name} ${target_name}_shared ${target_name}_static - PROPERTIES - MSVC_RUNTIME_LIBRARY "MultiThreaded$<$<CONFIG:Debug>:Debug>" - ) endif() add_dsymutil(${target_name}_shared) + add_msvc_flags(${target_name}_shared) endfunction() diff --git a/ffi/include/tvm/ffi/base_details.h b/ffi/include/tvm/ffi/base_details.h index cc54be77bb..18cc3ecb72 100644 --- a/ffi/include/tvm/ffi/base_details.h +++ b/ffi/include/tvm/ffi/base_details.h @@ -32,7 +32,20 @@ #include <utility> #if defined(_MSC_VER) +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif + +#ifndef NOMINMAX +#define NOMINMAX +#endif + #include <windows.h> + +#ifdef ERROR +#undef ERROR +#endif + #endif #if defined(_MSC_VER) diff --git a/ffi/include/tvm/ffi/container/container_details.h b/ffi/include/tvm/ffi/container/container_details.h index 39ebcb4a58..c842218815 100644 --- a/ffi/include/tvm/ffi/container/container_details.h +++ b/ffi/include/tvm/ffi/container/container_details.h @@ -275,6 +275,15 @@ class ReverseIterAdapter { template <typename T> inline constexpr bool storage_enabled_v = std::is_same_v<T, Any> || TypeTraits<T>::storage_enabled; +/*! + * \brief Check if all T are compatible with Any. + * + * \tparam T The type to check. + * \return True if T is compatible with Any, false otherwise. + */ +template <typename ...T> +inline constexpr bool all_storage_enabled_v = (storage_enabled_v<T> && ...); + /** * \brief Check if Any storage of Derived can always be directly used as Base. 
* diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h index 065c8db2cd..aeb0b0d308 100644 --- a/ffi/include/tvm/ffi/container/ndarray.h +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -180,7 +180,7 @@ class NDArrayObjFromNDAlloc : public NDArrayObj { ExtraArgs&&... extra_args) : alloc_(alloc) { this->device = device; - this->ndim = shape.size(); + this->ndim = static_cast<int>(shape.size()); this->dtype = dtype; this->shape = const_cast<int64_t*>(shape.data()); this->strides = nullptr; diff --git a/ffi/include/tvm/ffi/container/tuple.h b/ffi/include/tvm/ffi/container/tuple.h index 8cc17eb302..1fff225aed 100644 --- a/ffi/include/tvm/ffi/container/tuple.h +++ b/ffi/include/tvm/ffi/container/tuple.h @@ -33,6 +33,7 @@ namespace tvm { namespace ffi { + /*! * \brief Typed tuple like std::tuple backed by ArrayObj container. * @@ -43,8 +44,7 @@ namespace ffi { template <typename... Types> class Tuple : public ObjectRef { public: - static constexpr bool all_storage_enabled_v = (details::storage_enabled_v<Types> && ...); - static_assert(all_storage_enabled_v, "All types used in Tuple<...> must be compatible with Any"); + static_assert(details::all_storage_enabled_v<Types...>, "All types used in Tuple<...> must be compatible with Any"); Tuple() : ObjectRef(MakeDefaultTupleNode()) {} Tuple(const Tuple<Types...>& other) : ObjectRef(other) {} diff --git a/ffi/include/tvm/ffi/container/variant.h b/ffi/include/tvm/ffi/container/variant.h index 6c34a8a15c..caf11557e6 100644 --- a/ffi/include/tvm/ffi/container/variant.h +++ b/ffi/include/tvm/ffi/container/variant.h @@ -34,6 +34,7 @@ namespace tvm { namespace ffi { + /*! * \brief A typed variant container. * @@ -42,10 +43,8 @@ namespace ffi { template <typename... 
V> class Variant { public: - static constexpr bool all_compatible_with_any_v = (TypeTraits<V>::storage_enabled && ...); - static_assert(all_compatible_with_any_v, + static_assert(details::all_storage_enabled_v<V...>, "All types used in Variant<...> must be compatible with Any"); - /* * \brief Helper utility to check if the type can be contained in the variant */ diff --git a/ffi/include/tvm/ffi/error.h b/ffi/include/tvm/ffi/error.h index 8c97ebc8e9..c39b0ff64e 100644 --- a/ffi/include/tvm/ffi/error.h +++ b/ffi/include/tvm/ffi/error.h @@ -134,7 +134,12 @@ class ErrorBuilder { public: explicit ErrorBuilder(std::string kind, std::string traceback, bool log_before_throw) : kind_(kind), traceback_(traceback), log_before_throw_(log_before_throw) {} - + +// MSVC warns (C4722) that this destructor never returns; that is expected here, so silence it locally +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4722) +#endif // avoid inline to reduce binary size, error throw path do not need to be fast [[noreturn]] ~ErrorBuilder() noexcept(false) { ::tvm::ffi::Error error(std::move(kind_), stream_.str(), std::move(traceback_)); @@ -143,6 +148,9 @@ class ErrorBuilder { } throw error; } +#ifdef _MSC_VER +#pragma warning(pop) +#endif std::ostringstream& stream() { return stream_; } diff --git a/ffi/include/tvm/ffi/function_details.h b/ffi/include/tvm/ffi/function_details.h index 19425f4fda..34e166428e 100644 --- a/ffi/include/tvm/ffi/function_details.h +++ b/ffi/include/tvm/ffi/function_details.h @@ -79,8 +79,11 @@ struct FuncFunctorImpl { using RetType = R; /*! \brief total number of arguments*/ static constexpr size_t num_args = sizeof...(Args); + // MSVC is not that friendly to in-template nested bool evaluation +#ifndef _MSC_VER /*! \brief Whether this function can be converted to ffi::Function via FromUnpacked */ static constexpr bool unpacked_supported = (ArgSupported<Args> && ...)
&& (RetSupported<R>); +#endif static TVM_FFI_INLINE std::string Sig() { using IdxSeq = std::make_index_sequence<sizeof...(Args)>; diff --git a/ffi/src/ffi/traceback_win.cc b/ffi/src/ffi/traceback_win.cc index 786b70447a..1de4c88681 100644 --- a/ffi/src/ffi/traceback_win.cc +++ b/ffi/src/ffi/traceback_win.cc @@ -45,6 +45,7 @@ std::string Traceback() { HANDLE process = GetCurrentProcess(); HANDLE thread = GetCurrentThread(); + SymSetOptions(SYMOPT_LOAD_LINES | SYMOPT_UNDNAME); SymInitialize(process, NULL, TRUE); CONTEXT context = {}; RtlCaptureContext(&context); @@ -83,7 +84,8 @@ std::string Traceback() { const char* symbol = "<unknown>"; int lineno = 0; // Get file and line number - IMAGEHLP_LINE64 line_info = {}; + IMAGEHLP_LINE64 line_info; + ZeroMemory(&line_info, sizeof(IMAGEHLP_LINE64)); line_info.SizeOfStruct = sizeof(IMAGEHLP_LINE64); DWORD displacement32 = 0; diff --git a/ffi/tests/cpp/CMakeLists.txt b/ffi/tests/cpp/CMakeLists.txt index 3ae2df73c8..429683600b 100644 --- a/ffi/tests/cpp/CMakeLists.txt +++ b/ffi/tests/cpp/CMakeLists.txt @@ -17,7 +17,10 @@ set_target_properties( add_cxx_warning(tvm_ffi_tests) add_sanitizer_address(tvm_ffi_tests) add_dsymutil(tvm_ffi_tests) -add_msvc_compact_defs(tvm_ffi_tests) +add_msvc_flags(tvm_ffi_tests) target_link_libraries(tvm_ffi_tests PRIVATE tvm_ffi_shared) - add_googletest(tvm_ffi_tests) + +if (MSVC) + target_link_options(tvm_ffi_tests PRIVATE /DEBUG) +endif() diff --git a/ffi/tests/cpp/test_array.cc b/ffi/tests/cpp/test_array.cc index 15ed0bff06..5062f6dd2d 100644 --- a/ffi/tests/cpp/test_array.cc +++ b/ffi/tests/cpp/test_array.cc @@ -72,7 +72,7 @@ TEST(Array, Map) { // Basic functionality TInt x(1), y(1); Array<TInt> var_arr{x, y}; - Array<TNumber> expr_arr = var_arr.Map([](TInt var) -> TNumber { return TFloat(var->value + 1); }); + Array<TNumber> expr_arr = var_arr.Map([](TInt var) -> TNumber { return TFloat(static_cast<double>(var->value + 1)); }); EXPECT_NE(var_arr.get(), expr_arr.get()); 
EXPECT_TRUE(expr_arr[0]->IsInstance<TFloatObj>()); @@ -94,7 +94,7 @@ TEST(Array, PushPop) { ASSERT_EQ(a.front(), b.front()); ASSERT_EQ(a.back(), b.back()); ASSERT_EQ(a.size(), b.size()); - int n = a.size(); + int n = static_cast<int>(a.size()); for (int j = 0; j < n; ++j) { ASSERT_EQ(a[j], b[j]); } @@ -105,7 +105,7 @@ TEST(Array, PushPop) { ASSERT_EQ(a.size(), b.size()); a.pop_back(); b.pop_back(); - int n = a.size(); + int n = static_cast<int>(a.size()); for (int j = 0; j < n; ++j) { ASSERT_EQ(a[j], b[j]); } @@ -161,8 +161,8 @@ TEST(Array, InsertEraseRange) { static_assert(std::is_same_v<decltype(*range_a.begin()), int>); for (size_t n = 1; n <= 10; ++n) { - a.insert(a.end(), n); - b.insert(b.end(), n); + a.insert(a.end(), static_cast<int>(n)); + b.insert(b.end(), static_cast<int>(n)); for (size_t pos = 0; pos <= n; ++pos) { a.insert(a.begin() + pos, range_a.begin(), range_a.end()); b.insert(b.begin() + pos, range_b.begin(), range_b.end()); diff --git a/ffi/tests/cpp/test_map.cc b/ffi/tests/cpp/test_map.cc index c579f47442..aa449141be 100644 --- a/ffi/tests/cpp/test_map.cc +++ b/ffi/tests/cpp/test_map.cc @@ -279,7 +279,7 @@ TEST(Map, MapInsertOrder) { // test that map preserves the insertion order auto get_reverse_order = [](size_t size) { std::vector<int> reverse_order; - for (int i = size; i != 0; --i) { + for (int i = static_cast<int>(size); i != 0; --i) { reverse_order.push_back(i - 1); } return reverse_order; diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc index 811227f073..3d7b00cd33 100644 --- a/ffi/tests/cpp/test_ndarray.cc +++ b/ffi/tests/cpp/test_ndarray.cc @@ -41,7 +41,7 @@ TEST(NDArray, Basic) { EXPECT_EQ(shape[2], 3); EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1})); for (int64_t i = 0; i < shape.Product(); ++i) { - reinterpret_cast<float*>(nd->data)[i] = i; + reinterpret_cast<float*>(nd->data)[i] = static_cast<float>(i); } Any any0 = nd; diff --git a/ffi/tests/cpp/test_rvalue_ref.cc b/ffi/tests/cpp/test_rvalue_ref.cc 
index 65d86d764a..d3a82c7158 100644 --- a/ffi/tests/cpp/test_rvalue_ref.cc +++ b/ffi/tests/cpp/test_rvalue_ref.cc @@ -48,7 +48,7 @@ TEST(RValueRef, Basic) { TEST(RValueRef, ParamChecking) { // try decution - Function fadd1 = Function::FromUnpacked([](TInt a) -> int { return a->value + 1; }); + Function fadd1 = Function::FromUnpacked([](TInt a) -> int64_t { return a->value + 1; }); // convert that triggers error EXPECT_THROW( diff --git a/ffi/tests/cpp/test_variant.cc b/ffi/tests/cpp/test_variant.cc index 77e4b8eee1..65b8a1c9e6 100644 --- a/ffi/tests/cpp/test_variant.cc +++ b/ffi/tests/cpp/test_variant.cc @@ -74,7 +74,7 @@ TEST(Variant, ObjectPtrHashEqual) { TEST(Variant, FromUnpacked) { // try decution - Function fadd1 = Function::FromUnpacked([](const Variant<int, TInt>& a) -> int { + Function fadd1 = Function::FromUnpacked([](const Variant<int, TInt>& a) -> int64_t { if (auto opt_int = a.as<int>()) { return opt_int.value() + 1; } else { @@ -100,7 +100,7 @@ TEST(Variant, FromUnpacked) { }, ::tvm::ffi::Error); - Function fadd2 = Function::FromUnpacked([](const Array<Variant<int, TInt>>& a) -> int { + Function fadd2 = Function::FromUnpacked([](const Array<Variant<int, TInt>>& a) -> int64_t { if (auto opt_int = a[0].as<int>()) { return opt_int.value() + 1; } else { diff --git a/ffi/tests/cpp/testing_object.h b/ffi/tests/cpp/testing_object.h index e2ee297b05..d0db5ca094 100644 --- a/ffi/tests/cpp/testing_object.h +++ b/ffi/tests/cpp/testing_object.h @@ -63,10 +63,11 @@ class TIntObj : public TNumberObj { static constexpr const char* _type_key = "test.Int"; TVM_FFI_DECLARE_FINAL_OBJECT_INFO(TIntObj, TNumberObj); - - TVM_FFI_REFLECTION_DEF(TIntObj).def_readonly("value", &TIntObj::value); }; +TVM_FFI_REFLECTION_DEF(TIntObj).def_readonly("value", &TIntObj::value); + + class TInt : public TNumber { public: explicit TInt(int64_t value) { data_ = make_object<TIntObj>(value); } @@ -123,15 +124,15 @@ struct TypeTraits<testing::TPrimExpr> : public 
ObjectRefWithFallbackTraitsBase<testing::TPrimExpr, StrictBool, int64_t, double, String> { static TVM_FFI_INLINE testing::TPrimExpr ConvertFallbackValue(StrictBool value) { - return testing::TPrimExpr("bool", value); + return testing::TPrimExpr("bool", static_cast<double>(value)); } static TVM_FFI_INLINE testing::TPrimExpr ConvertFallbackValue(int64_t value) { - return testing::TPrimExpr("int64", value); + return testing::TPrimExpr("int64", static_cast<double>(value)); } static TVM_FFI_INLINE testing::TPrimExpr ConvertFallbackValue(double value) { - return testing::TPrimExpr("float32", value); + return testing::TPrimExpr("float32", static_cast<double>(value)); } // hack into the dtype to store string static TVM_FFI_INLINE testing::TPrimExpr ConvertFallbackValue(String value) { diff --git a/include/tvm/ir/name_supply.h b/include/tvm/ir/name_supply.h index 11dac3fe52..136c95741e 100644 --- a/include/tvm/ir/name_supply.h +++ b/include/tvm/ir/name_supply.h @@ -28,6 +28,7 @@ #include <string> #include <unordered_map> #include <utility> +#include <cctype> #include "tvm/ir/expr.h" diff --git a/include/tvm/meta_schedule/search_strategy.h b/include/tvm/meta_schedule/search_strategy.h index 3f44a2438d..aeef1bff30 100644 --- a/include/tvm/meta_schedule/search_strategy.h +++ b/include/tvm/meta_schedule/search_strategy.h @@ -41,38 +41,38 @@ class SearchStrategy; /*! * \brief The search strategy for measure candidates generation. 
* \note The relationship between SearchStrategy and other classes are as follows: - ┌──────────────────────────────────────────────────────────────┐ - ┌──┴───────────────────────────────────────────────────────────┐ │ -┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ -│ ┌─────────────────────┐ │ │ │ -│ │ │ Generate │ │ │ -│ │ Space Generator ├──────────────┐ │ │ │ -│ │ │ │ │ │ │ -│ └─────────────────────┘ ▼ │ │ │ -│ Design Space │ │ │ -│ ┌─────────────────────┐ │ │ │ │ -│ Generate │ │ Pretuning │ │ │ │ -│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ -│ │ │ │ │ ├──┘ -│ │ └─────────────────────┘ ├──┘ -└────┼─────────────────────────────────────────────────────────┘ - │ - │ -┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ -│ │ ┌───────────┐ │ -│ │ Send to │ │ Send to │ -│ ▼ ┌─────────────►│ Builder ├──────────┐ │ -│ Measure Candidate │ Builder │ │ Runner │ │ -│ │ │ └───────────┘ │ │ -│ │ ┌────────────┴────────┐ │ │ -│ │ │ │ ┌───────────┐ │ │ -│ └────►│ Task Scheduler │ │ │ │ │ -│ │ │ │ Runner │◄─────────┘ │ -│ └─────────────────────┘ │ │ │ -│ ▲ └─────┬─────┘ │ -│ │ │ │ -│ └─── Runner Future ◄────┘ │ -└─────────────────────────────────────────────────────────────────────┘ + +--------------------------------------------------------------+ + +--+-----------------------------------------------------------+ | + +--+------------------ Tune Context -----------------------------+ | | + | +---------------------+ | | | + | | | Generate | | | + | | Space Generator +--------------+ | | | + | | | | | | | + | +---------------------+ v | | | + | Design Space | | | + | +---------------------+ | | | | + | Generate | | Pretuning | | | | + | +-----------+ Search Strategy |<-------------+ | | | + | | | | | +--+ + | | +---------------------+ +--+ + +----+----------------------------------------------------------+ + | + | + +----+---------------- Managed By Task Scheduler ---------------------+ + | | +-----------+ | + | | Send to | | Send 
to | + | v +-------------+| Builder +----------+ | + | Measure Candidate | Builder | | Runner | | + | | | +-----------+ | | + | | +------------+------------+ | | + | | | | +-----------+ | | + | +---->| Task Scheduler | | | | | + | | | | Runner |<-----+ | + | +-------------------------+ | | | + | ^ +-----+-----+ | + | | | | + | +---- Runner Future <-------+ | + +---------------------------------------------------------------------+ */ class SearchStrategyNode : public runtime::Object { public: diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index f746eb8091..650320d1e2 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -40,38 +40,38 @@ class SpaceGenerator; /*! * \brief The abstract class for design space generation. * \note The relationship between SpaceGenerator and other classes are as follows: - ┌──────────────────────────────────────────────────────────────┐ - ┌──┴───────────────────────────────────────────────────────────┐ │ -┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ -│ ┌─────────────────────┐ │ │ │ -│ │ │ Generate │ │ │ -│ │ Space Generator ├──────────────┐ │ │ │ -│ │ │ │ │ │ │ -│ └─────────────────────┘ ▼ │ │ │ -│ Design Space │ │ │ -│ ┌─────────────────────┐ │ │ │ │ -│ Generate │ │ Pretuning │ │ │ │ -│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ -│ │ │ │ │ ├──┘ -│ │ └─────────────────────┘ ├──┘ -└────┼─────────────────────────────────────────────────────────┘ - │ - │ -┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ -│ │ ┌───────────┐ │ -│ │ Send to │ │ Send to │ -│ ▼ ┌─────────────►│ Builder ├──────────┐ │ -│ Measure Candidate │ Builder │ │ Runner │ │ -│ │ │ └───────────┘ │ │ -│ │ ┌────────────┴────────┐ │ │ -│ │ │ │ ┌───────────┐ │ │ -│ └────►│ Task Scheduler │ │ │ │ │ -│ │ │ │ Runner │◄─────────┘ │ -│ └─────────────────────┘ │ │ │ -│ ▲ └─────┬─────┘ │ -│ │ │ │ -│ └─── Runner Future 
◄────┘ │ -└─────────────────────────────────────────────────────────────────────┘ + +--------------------------------------------------------------+ + +--+-----------------------------------------------------------+ | + +--+------------------ Tune Context -----------------------------+ | | + | +---------------------+ | | | + | | | Generate | | | + | | Space Generator +--------------+ | | | + | | | | | | | + | +---------------------+ v | | | + | Design Space | | | + | +---------------------+ | | | | + | Generate | | Pretuning | | | | + | +-----------+ Search Strategy |<-------------+ | | | + | | | | | +--+ + | | +---------------------+ +--+ + +----+----------------------------------------------------------+ + | + | + +----+---------------- Managed By Task Scheduler ---------------------+ + | | +-----------+ | + | | Send to | | Send to | + | v +-------------+| Builder +----------+ | + | Measure Candidate | Builder | | Runner | | + | | | +-----------+ | | + | | +------------+------------+ | | + | | | | +-----------+ | | + | +---->| Task Scheduler | | | | | + | | | | Runner |<-----+ | + | +-------------------------+ | | | + | ^ +-----+-----+ | + | | | | + | +---- Runner Future <-------+ | + +---------------------------------------------------------------------+ */ class SpaceGeneratorNode : public runtime::Object { public: diff --git a/include/tvm/meta_schedule/task_scheduler.h b/include/tvm/meta_schedule/task_scheduler.h index f4fc491286..8cc3595d68 100644 --- a/include/tvm/meta_schedule/task_scheduler.h +++ b/include/tvm/meta_schedule/task_scheduler.h @@ -92,38 +92,38 @@ class TaskRecord : public runtime::ObjectRef { /*! * \brief The abstract interface of task schedulers. 
* \note The relationship between SpaceGenerator and other classes are as follows: - ┌──────────────────────────────────────────────────────────────┐ - ┌──┴───────────────────────────────────────────────────────────┐ │ -┌──┴────────────────── Tune Context ───────────────────────────┐ │ │ -│ ┌─────────────────────┐ │ │ │ -│ │ │ Generate │ │ │ -│ │ Space Generator ├──────────────┐ │ │ │ -│ │ │ │ │ │ │ -│ └─────────────────────┘ ▼ │ │ │ -│ Design Space │ │ │ -│ ┌─────────────────────┐ │ │ │ │ -│ Generate │ │ Pretuning │ │ │ │ -│ ┌───────────┤ Search Strategy │◄─────────────┘ │ │ │ -│ │ │ │ │ ├──┘ -│ │ └─────────────────────┘ ├──┘ -└────┼─────────────────────────────────────────────────────────┘ - │ - │ -┌────┼──────────────── Managed By Task Scheduler ─────────────────────┐ -│ │ ┌───────────┐ │ -│ │ Send to │ │ Send to │ -│ ▼ ┌─────────────►│ Builder ├──────────┐ │ -│ Measure Candidate │ Builder │ │ Runner │ │ -│ │ │ └───────────┘ │ │ -│ │ ┌────────────┴────────┐ │ │ -│ │ │ │ ┌───────────┐ │ │ -│ └────►│ Task Scheduler │ │ │ │ │ -│ │ │ │ Runner │◄─────────┘ │ -│ └─────────────────────┘ │ │ │ -│ ▲ └─────┬─────┘ │ -│ │ │ │ -│ └─── Runner Future ◄────┘ │ -└─────────────────────────────────────────────────────────────────────┘ + +--------------------------------------------------------------+ + +--+-----------------------------------------------------------+ | + +--+------------------ Tune Context -----------------------------+ | | + | +---------------------+ | | | + | | | Generate | | | + | | Space Generator +--------------+ | | | + | | | | | | | + | +---------------------+ v | | | + | Design Space | | | + | +---------------------+ | | | | + | Generate | | Pretuning | | | | + | +-----------+ Search Strategy |<-------------+ | | | + | | | | | +--+ + | | +---------------------+ +--+ + +----+----------------------------------------------------------+ + | + | + +----+---------------- Managed By Task Scheduler ---------------------+ + | | +-----------+ | + | | Send to | | Send 
to | + | v +-------------+| Builder +----------+ | + | Measure Candidate | Builder | | Runner | | + | | | +-----------+ | | + | | +------------+------------+ | | + | | | | +-----------+ | | + | +---->| Task Scheduler | | | | | + | | | | Runner |<-----+ | + | +-------------------------+ | | | + | ^ +-----+-----+ | + | | | | + | +---- Runner Future <-------+ | + +---------------------------------------------------------------------+ */ class TaskSchedulerNode : public runtime::Object { public: diff --git a/include/tvm/runtime/threading_backend.h b/include/tvm/runtime/threading_backend.h index 4d09f43f95..27a8a546c8 100644 --- a/include/tvm/runtime/threading_backend.h +++ b/include/tvm/runtime/threading_backend.h @@ -116,7 +116,7 @@ class ThreadGroup { /*! * \brief Platform-agnostic no-op. */ -TVM_DLL void Yield(); +TVM_DLL void YieldThread(); /*! * \return the maximum number of effective workers for this system. */ diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 72b49df91e..efcf4a47d6 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -816,13 +816,18 @@ class SeqStmt : public Stmt { static Optional<SeqStmt> AsSeqStmt(const T& t) { if constexpr (std::is_same_v<T, SeqStmt>) { return t; - } else if constexpr (!std::is_base_of_v<T, SeqStmt>) { - return NullOpt; - } else if (auto* ptr = t.template as<SeqStmtNode>()) { - return GetRef<SeqStmt>(ptr); - } else { + } + if constexpr (!std::is_base_of_v<T, SeqStmt>) { return NullOpt; } + if constexpr (std::is_base_of_v<Stmt, T>) { + if (const SeqStmtNode* ptr = t.template as<SeqStmtNode>()) { + return GetRef<SeqStmt>(ptr); + } else { + return NullOpt; + } + } + return NullOpt; } template <typename T> diff --git a/src/ir/function.cc b/src/ir/function.cc index b8de83612b..8f543b0326 100644 --- a/src/ir/function.cc +++ b/src/ir/function.cc @@ -58,6 +58,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithAttrs") } } LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); + 
TVM_FFI_UNREACHABLE(); }); TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") @@ -69,7 +70,7 @@ TVM_REGISTER_GLOBAL("ir.BaseFuncWithoutAttr") return WithoutAttr(Downcast<relax::Function>(std::move(func)), key); } else { LOG(FATAL) << "Do not support function type " << func->GetTypeKey(); - return func; + TVM_FFI_UNREACHABLE(); } }); diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index 8b36405baf..aaebf3db7f 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -45,6 +45,7 @@ #include <algorithm> #include <string> #include <unordered_set> +#include <sstream> #include <utility> #include <vector> @@ -636,6 +637,8 @@ class BlockCollector : public tir::StmtVisitor { String func_name_; }; +void JSONFileAppendLine(const String& path, const std::string& line); +std::vector<Any> JSONFileReadLines(const String& path, int num_threads, bool allow_missing); } // namespace meta_schedule } // namespace tvm diff --git a/src/relax/op/tensor/sorting.cc b/src/relax/op/tensor/sorting.cc index c4c4c5a614..9f8545e9b3 100644 --- a/src/relax/op/tensor/sorting.cc +++ b/src/relax/op/tensor/sorting.cc @@ -142,6 +142,7 @@ StructInfo InferStructInfoTopK(const Call& call, const BlockBuilder& ctx) { return output_sinfos[1]; } LOG(FATAL) << "Unsupported ret type: " << ret_type; + TVM_FFI_UNREACHABLE(); } TVM_REGISTER_OP("relax.topk") diff --git a/src/relax/transform/lower_alloc_tensor.cc b/src/relax/transform/lower_alloc_tensor.cc index 1c3e283768..13705c0908 100644 --- a/src/relax/transform/lower_alloc_tensor.cc +++ b/src/relax/transform/lower_alloc_tensor.cc @@ -59,6 +59,7 @@ class Mutator : public ExprMutator { LOG(FATAL) << "Shape argument for " << alloc_tensor_op << " should be a ShapeExpr, " << "or a variable that holds a ShapeExpr. 
" << "However, received argument " << shape_arg << " with struct info " << sinfo; + TVM_FFI_UNREACHABLE(); }(); PrimExpr nbytes = [&]() -> PrimExpr { diff --git a/src/relax/transform/realize_vdevice.cc b/src/relax/transform/realize_vdevice.cc index 5208767fea..9e6ebbbc26 100644 --- a/src/relax/transform/realize_vdevice.cc +++ b/src/relax/transform/realize_vdevice.cc @@ -72,6 +72,7 @@ class VDeviceLookup { LOG(FATAL) << "ValueError: " << "Expected to find device with type " << device_id << " and id " << device_id << ", but no such device was found in the IRModule's \"vdevice\" annotation"; + TVM_FFI_UNREACHABLE(); } private: diff --git a/src/relax/transform/tuning_api/database.cc b/src/relax/transform/tuning_api/database.cc index c4e706ebe4..55c97d0b85 100644 --- a/src/relax/transform/tuning_api/database.cc +++ b/src/relax/transform/tuning_api/database.cc @@ -29,14 +29,6 @@ #include "../../../meta_schedule/utils.h" -namespace tvm { -namespace meta_schedule { - -void JSONFileAppendLine(const String& path, const std::string& line); -std::vector<ObjectRef> JSONFileReadLines(const String& path, int num_threads, bool allow_missing); - -} // namespace meta_schedule -} // namespace tvm namespace tvm { namespace relax { @@ -237,20 +229,20 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, // Load `n->workloads2idx_` from `path_workload` std::vector<meta_schedule::Workload> workloads; { - std::vector<ObjectRef> json_objs = + std::vector<Any> json_objs = meta_schedule::JSONFileReadLines(path_workload, num_threads, allow_missing); int n_objs = json_objs.size(); n->workloads2idx_.reserve(n_objs); workloads.reserve(n_objs); for (int i = 0; i < n_objs; ++i) { - meta_schedule::Workload workload = meta_schedule::Workload::FromJSON(json_objs[i]); + meta_schedule::Workload workload = meta_schedule::Workload::FromJSON(json_objs[i].cast<ObjectRef>()); n->workloads2idx_.emplace(workload, i); workloads.push_back(workload); } } // Load `n->tuning_records_` 
from `path_tuning_record` { - std::vector<ObjectRef> json_objs = + std::vector<Any> json_objs = meta_schedule::JSONFileReadLines(path_tuning_record, num_threads, allow_missing); std::vector<int> workload_idxs; @@ -262,7 +254,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, records.resize(size, TuningRecord{nullptr}); support::parallel_for_dynamic( 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { - const ObjectRef& json_obj = json_objs[task_id]; + const ObjectRef& json_obj = json_objs[task_id].cast<ObjectRef>(); try { const ArrayObj* arr = json_obj.as<ArrayObj>(); ICHECK_EQ(arr->size(), 3); @@ -283,7 +275,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, // Load `n->measuremet_log` from `path_measurement_record` { - std::vector<ObjectRef> json_objs = + std::vector<Any> json_objs = meta_schedule::JSONFileReadLines(path_measurement_record, num_threads, allow_missing); std::vector<int> workload_idxs; std::vector<Target> targets; @@ -294,7 +286,7 @@ Database Database::JSONDatabase(String path_workload, String path_tuning_record, measurements.resize(size, Array<FloatImm>({})); support::parallel_for_dynamic( 0, json_objs.size(), num_threads, [&](int thread_id, int task_id) { - const ObjectRef& json_obj = json_objs[task_id]; + const ObjectRef& json_obj = json_objs[task_id].cast<ObjectRef>(); try { const ArrayObj* arr = json_obj.as<ArrayObj>(); ICHECK_EQ(arr->size(), 3); diff --git a/src/runtime/thread_pool.cc b/src/runtime/thread_pool.cc index 7537df718f..de7f38ed28 100644 --- a/src/runtime/thread_pool.cc +++ b/src/runtime/thread_pool.cc @@ -98,7 +98,7 @@ class ParallelLauncher { // Wait n jobs to finish int WaitForJobs() { while (num_pending_.load() != 0) { - tvm::runtime::threading::Yield(); + tvm::runtime::threading::YieldThread(); } if (!has_error_.load()) return 0; std::ostringstream os; @@ -163,7 +163,7 @@ class SpscTaskQueue { */ void Push(const Task& input) { while 
(!Enqueue(input)) { - tvm::runtime::threading::Yield(); + tvm::runtime::threading::YieldThread(); } if (pending_.fetch_add(1) == -1) { std::unique_lock<std::mutex> lock(mutex_); @@ -182,7 +182,7 @@ class SpscTaskQueue { // If a new task comes to the queue quickly, this wait avoid the worker from sleeping. // The default spin count is set by following the typical omp convention for (uint32_t i = 0; i < spin_count && pending_.load() == 0; ++i) { - tvm::runtime::threading::Yield(); + tvm::runtime::threading::YieldThread(); } if (pending_.fetch_sub(1) == 0) { std::unique_lock<std::mutex> lock(mutex_); @@ -511,7 +511,7 @@ int TVMBackendParallelBarrier(int task_id, TVMParallelGroupEnv* penv) { for (int i = 0; i < num_task; ++i) { if (i != task_id) { while (sync_counter[i * kSyncStride].load(std::memory_order_relaxed) <= old_counter) { - tvm::runtime::threading::Yield(); + tvm::runtime::threading::YieldThread(); } } } diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 177ecf5110..01c0f1603f 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -369,7 +369,7 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0 return impl_->Configure(mode, nthreads, exclude_worker0, cpus); } -void Yield() { +void YieldThread() { #ifdef __hexagon__ // QuRT doesn't have a yield API, so instead we sleep for the minimum amount // of time to let the OS schedule another thread. 
std::this_thread::yield() diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc index 35a9f35db4..2ae5663385 100644 --- a/src/script/printer/relax/tir.cc +++ b/src/script/printer/relax/tir.cc @@ -66,6 +66,7 @@ Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) { return doc.value(); } LOG(FATAL) << "IndexError: Variable is not defined in the environment: " << n; + TVM_FFI_UNREACHABLE(); } TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable).set_dispatch<tir::Var>("relax", PrintTIRVar); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 87db53061c..18c7afe504 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -317,6 +317,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) // return doc.value(); } LOG(FATAL) << "IndexError: Buffer is not defined in the environment: " << buffer; + TVM_FFI_UNREACHABLE(); }); TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) diff --git a/src/tir/analysis/stmt_finding.cc b/src/tir/analysis/stmt_finding.cc index 300f779ae9..527c7d04c6 100644 --- a/src/tir/analysis/stmt_finding.cc +++ b/src/tir/analysis/stmt_finding.cc @@ -98,6 +98,7 @@ Stmt GetEnclosingLoop(const BlockNode* block, Stmt func_body) { } LOG(FATAL) << "Enclosing loop not found for a block " << GetRef<Block>(block); + TVM_FFI_UNREACHABLE(); } const BlockNode* FindAnchorBlock(const IRModule& mod) { diff --git a/tests/cpp/llvm_codegen_registry_test.cc b/tests/cpp/llvm_codegen_registry_test.cc index 5e2a167214..534d4c8e41 100644 --- a/tests/cpp/llvm_codegen_registry_test.cc +++ b/tests/cpp/llvm_codegen_registry_test.cc @@ -55,7 +55,7 @@ TEST(LLVMCodeGen, CodeGenFactoryWorks) { std::initializer_list<std::string> all_targets = {ALL_TARGETS}; for (const std::string& s : all_targets) { if (auto pf = tvm::ffi::Function::GetGlobal("tvm.codegen.llvm.target_" + s)) { - auto cg = static_cast<void*>((*pf)()); + auto cg = (*pf)().cast<void*>(); EXPECT_NE(cg, nullptr); delete 
static_cast<tvm::codegen::CodeGenLLVM*>(cg); } diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc index baa5e24bb5..17e3cae4ad 100644 --- a/tests/cpp/target_test.cc +++ b/tests/cpp/target_test.cc @@ -462,7 +462,7 @@ TEST(TargetCreation, DetectSystemTriple) { } Optional<String> mtriple = target->GetAttr<String>("mtriple"); - ASSERT_TRUE(mtriple.value() == String((*pf)())); + ASSERT_TRUE(mtriple.value() == (*pf)().cast<String>()); } #endif
