This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/refactor-s2 by this push:
new 6a41ea59d7 Followup testcase fixes
6a41ea59d7 is described below
commit 6a41ea59d777e0fc5675238b431c1d46e48b647e
Author: tqchen <[email protected]>
AuthorDate: Sat May 3 09:29:44 2025 -0400
Followup testcase fixes
---
ffi/tests/cpp/test_tuple.cc | 2 +-
src/runtime/contrib/dnnl/dnnl.cc | 5 +++--
src/runtime/contrib/tflite/tflite_runtime.cc | 9 +++++----
3 files changed, 9 insertions(+), 7 deletions(-)
diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc
index 1fe9ca74e8..02a258522c 100644
--- a/ffi/tests/cpp/test_tuple.cc
+++ b/ffi/tests/cpp/test_tuple.cc
@@ -78,7 +78,7 @@ TEST(Tuple, AnyConvert) {
Any any0 = view0;
// trigger a copy due to implict conversion
- auto tuple2 = any0.cast<Tuple<TPrimExpr, TInt>>() ;
+ auto tuple2 = any0.cast<Tuple<TPrimExpr, TInt>>();
EXPECT_TRUE(!tuple0.same_as(tuple2));
EXPECT_EQ(tuple2.get<0>()->value, 1);
EXPECT_EQ(tuple2.get<1>()->value, 2);
diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc
index 822e8ac376..7fca0d6b26 100644
--- a/src/runtime/contrib/dnnl/dnnl.cc
+++ b/src/runtime/contrib/dnnl/dnnl.cc
@@ -352,8 +352,9 @@ TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body_packed([](TVMArgs args,
auto input = args[0].cast<DLTensor*>();
auto weights = args[1].cast<DLTensor*>();
auto output = args[2].cast<DLTensor*>();
-  int p_Ph0_ = args[3], p_Pw0_ = args[4], p_Ph1_ = args[5], p_Pw1_ = args[6], p_Sh_ = args[7],
-      p_Sw_ = args[8], p_G_ = args[9];
+  int p_Ph0_ = args[3].cast<int>(), p_Pw0_ = args[4].cast<int>(), p_Ph1_ = args[5].cast<int>(),
+      p_Pw1_ = args[6].cast<int>(), p_Sh_ = args[7].cast<int>(), p_Sw_ = args[8].cast<int>(),
+      p_G_ = args[9].cast<int>();
bool channel_last = args[10].cast<bool>();
bool pre_cast = args[11].cast<bool>();
bool post_cast = args[12].cast<bool>();
diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc b/src/runtime/contrib/tflite/tflite_runtime.cc
index 82257e2a82..7d7f22690c 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.cc
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -156,11 +156,12 @@ PackedFunc TFLiteRuntime::GetFunction(const String& name, const ObjectPtr<Object
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
int in_idx = args[0].cast<int>();
ICHECK_GE(in_idx, 0);
- this->SetInput(in_idx, args[1]);
+ this->SetInput(in_idx, args[1].cast<DLTensor*>());
});
} else if (name == "get_output") {
-    return PackedFunc(
-        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetOutput(args[0]); });
+ return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+ *rv = this->GetOutput(args[0].cast<int>());
+ });
} else if (name == "invoke") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Invoke(); });
} else if (name == "set_num_threads") {
@@ -181,7 +182,7 @@ Module TFLiteRuntimeCreate(const std::string& tflite_model_bytes, Device dev) {
}
TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body_packed([](TVMArgs args, TVMRetValue* rv) {
- *rv = TFLiteRuntimeCreate(args[0], args[1]);
+ *rv = TFLiteRuntimeCreate(args[0].get<std::string>(), args[1].get<Device>());
});
TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate);