This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 8af49f3c3abc9b58effd486cee2ff9121f752256 Author: tqchen <[email protected]> AuthorDate: Sat May 3 09:29:44 2025 -0400 Followup testcase fixes --- ffi/src/ffi/traceback_win.cc | 1 - ffi/tests/cpp/test_tuple.cc | 2 +- src/relax/backend/contrib/clml/codegen.cc | 2 +- src/relax/backend/contrib/cudnn/codegen.cc | 2 +- src/runtime/contrib/dnnl/dnnl.cc | 5 +++-- src/runtime/contrib/tflite/tflite_runtime.cc | 9 +++++---- tests/cpp-runtime/hexagon/run_all_tests.cc | 2 +- tests/cpp-runtime/hexagon/run_unit_tests.cc | 2 +- 8 files changed, 13 insertions(+), 12 deletions(-) diff --git a/ffi/src/ffi/traceback_win.cc b/ffi/src/ffi/traceback_win.cc index 1de4c88681..cc2a0d1adc 100644 --- a/ffi/src/ffi/traceback_win.cc +++ b/ffi/src/ffi/traceback_win.cc @@ -22,7 +22,6 @@ * \note We use the term "traceback" to be consistent with python naming convention. */ #ifdef _MSC_VER - #include <tvm/ffi/c_api.h> #include <tvm/ffi/error.h> diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc index 1fe9ca74e8..02a258522c 100644 --- a/ffi/tests/cpp/test_tuple.cc +++ b/ffi/tests/cpp/test_tuple.cc @@ -78,7 +78,7 @@ TEST(Tuple, AnyConvert) { Any any0 = view0; // trigger a copy due to implict conversion - auto tuple2 = any0.cast<Tuple<TPrimExpr, TInt>>() ; + auto tuple2 = any0.cast<Tuple<TPrimExpr, TInt>>(); EXPECT_TRUE(!tuple0.same_as(tuple2)); EXPECT_EQ(tuple2.get<0>()->value, 1); EXPECT_EQ(tuple2.get<1>()->value, 2); diff --git a/src/relax/backend/contrib/clml/codegen.cc b/src/relax/backend/contrib/clml/codegen.cc index 7d3d243fc9..3c87079f99 100644 --- a/src/relax/backend/contrib/clml/codegen.cc +++ b/src/relax/backend/contrib/clml/codegen.cc @@ -317,7 +317,7 @@ Array<runtime::Module> OpenCLMLCompiler(Array<Function> functions, Map<String, A const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.clml_runtime_create"); std::string func_name = GetExtSymbol(func); VLOG(1) << "Creating clml runtime::Module for '" << func_name << "'"; - compiled_functions.push_back(pf(func_name, graph_json, constant_names)); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast<runtime::Module>()); } return compiled_functions; } diff --git a/src/relax/backend/contrib/cudnn/codegen.cc b/src/relax/backend/contrib/cudnn/codegen.cc index e183dadc26..43933c7d2a 100644 --- a/src/relax/backend/contrib/cudnn/codegen.cc +++ b/src/relax/backend/contrib/cudnn/codegen.cc @@ -143,7 +143,7 @@ Array<runtime::Module> cuDNNCompiler(Array<Function> functions, Map<String, ffi: auto constant_names = serializer.GetConstantNames(); const auto pf = tvm::ffi::Function::GetGlobalRequired("runtime.cuDNNJSONRuntimeCreate"); auto func_name = GetExtSymbol(func); - compiled_functions.push_back(pf(func_name, graph_json, constant_names)); + compiled_functions.push_back(pf(func_name, graph_json, constant_names).cast<runtime::Module>()); } return compiled_functions; diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc index 822e8ac376..7fca0d6b26 100644 --- a/src/runtime/contrib/dnnl/dnnl.cc +++ b/src/runtime/contrib/dnnl/dnnl.cc @@ -352,8 +352,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body_packed([](TVMArgs args, auto input = args[0].cast<DLTensor*>(); auto weights = args[1].cast<DLTensor*>(); auto output = args[2].cast<DLTensor*>(); - int p_Ph0_ = args[3], p_Pw0_ = args[4], p_Ph1_ = args[5], p_Pw1_ = args[6], p_Sh_ = args[7], - p_Sw_ = args[8], p_G_ = args[9]; + int p_Ph0_ = args[3].cast<int>(), p_Pw0_ = args[4].cast<int>(), p_Ph1_ = args[5].cast<int>(), + p_Pw1_ = args[6].cast<int>(), p_Sh_ = args[7].cast<int>(), p_Sw_ = args[8].cast<int>(), + p_G_ = args[9].cast<int>(); bool channel_last = args[10].cast<bool>(); bool pre_cast = args[11].cast<bool>(); bool post_cast = args[12].cast<bool>(); diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc index 82257e2a82..09669ac370 100644 --- a/src/runtime/contrib/tflite/tflite_runtime.cc +++ b/src/runtime/contrib/tflite/tflite_runtime.cc @@ -156,11 +156,12 @@ PackedFunc TFLiteRuntime::GetFunction(const String& name, const ObjectPtr<Object return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { int in_idx = args[0].cast<int>(); ICHECK_GE(in_idx, 0); - this->SetInput(in_idx, args[1]); + this->SetInput(in_idx, args[1].cast<DLTensor*>()); }); } else if (name == "get_output") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); }); + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetOutput(args[0].cast<int>()); + }); } else if (name == "invoke") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Invoke(); }); } else if (name == "set_num_threads") { @@ -181,7 +182,7 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) { } TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body_packed([](TVMArgs args, TVMRetValue* rv) { - *rv = TFLiteRuntimeCreate(args[0], args[1]); + *rv = TFLiteRuntimeCreate(args[0].cast<std::string>(), args[1].cast<Device>()); }); TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate); diff --git a/tests/cpp-runtime/hexagon/run_all_tests.cc b/tests/cpp-runtime/hexagon/run_all_tests.cc index 5187277dcb..0f9c1cb7b5 100644 --- a/tests/cpp-runtime/hexagon/run_all_tests.cc +++ b/tests/cpp-runtime/hexagon/run_all_tests.cc @@ -41,7 +41,7 @@ namespace hexagon { TVM_REGISTER_GLOBAL("hexagon.run_all_tests").set_body_packed([](TVMArgs args, TVMRetValue* rv) { // gtest args are passed into this packed func as a singular string // split gtest args using <space> delimiter and build argument vector - std::vector<std::string> parsed_args = tvm::support::Split(args[0], ' '); + std::vector<std::string> parsed_args = tvm::support::Split(args[0].cast<std::string>(), ' '); std::vector<char*> argv; // add executable name diff --git a/tests/cpp-runtime/hexagon/run_unit_tests.cc b/tests/cpp-runtime/hexagon/run_unit_tests.cc index 37a521457a..59059fc803 100644 --- a/tests/cpp-runtime/hexagon/run_unit_tests.cc +++ b/tests/cpp-runtime/hexagon/run_unit_tests.cc @@ -83,7 +83,7 @@ class GtestPrinter : public testing::EmptyTestEventListener { TVM_REGISTER_GLOBAL("hexagon.run_unit_tests").set_body_packed([](TVMArgs args, TVMRetValue* rv) { // gtest args are passed into this packed func as a singular string // split gtest args using <space> delimiter and build argument vector - std::vector<std::string> parsed_args = tvm::support::Split(args[0], ' '); + std::vector<std::string> parsed_args = tvm::support::Split(args[0].cast<std::string>(), ' '); std::vector<char*> argv; // add executable name
