This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/refactor-s2 by this push:
     new 6a41ea59d7 Followup testcase fixes
6a41ea59d7 is described below

commit 6a41ea59d777e0fc5675238b431c1d46e48b647e
Author: tqchen <[email protected]>
AuthorDate: Sat May 3 09:29:44 2025 -0400

    Followup testcase fixes
---
 ffi/tests/cpp/test_tuple.cc                  | 2 +-
 src/runtime/contrib/dnnl/dnnl.cc             | 5 +++--
 src/runtime/contrib/tflite/tflite_runtime.cc | 9 +++++----
 3 files changed, 9 insertions(+), 7 deletions(-)

diff --git a/ffi/tests/cpp/test_tuple.cc b/ffi/tests/cpp/test_tuple.cc
index 1fe9ca74e8..02a258522c 100644
--- a/ffi/tests/cpp/test_tuple.cc
+++ b/ffi/tests/cpp/test_tuple.cc
@@ -78,7 +78,7 @@ TEST(Tuple, AnyConvert) {
 
   Any any0 = view0;
   // trigger a copy due to implict conversion
-  auto tuple2 = any0.cast<Tuple<TPrimExpr, TInt>>() ;
+  auto tuple2 = any0.cast<Tuple<TPrimExpr, TInt>>();
   EXPECT_TRUE(!tuple0.same_as(tuple2));
   EXPECT_EQ(tuple2.get<0>()->value, 1);
   EXPECT_EQ(tuple2.get<1>()->value, 2);
diff --git a/src/runtime/contrib/dnnl/dnnl.cc b/src/runtime/contrib/dnnl/dnnl.cc
index 822e8ac376..7fca0d6b26 100644
--- a/src/runtime/contrib/dnnl/dnnl.cc
+++ b/src/runtime/contrib/dnnl/dnnl.cc
@@ -352,8 +352,9 @@ 
TVM_REGISTER_GLOBAL("tvm.contrib.dnnl.conv2d").set_body_packed([](TVMArgs args,
   auto input = args[0].cast<DLTensor*>();
   auto weights = args[1].cast<DLTensor*>();
   auto output = args[2].cast<DLTensor*>();
-  int p_Ph0_ = args[3], p_Pw0_ = args[4], p_Ph1_ = args[5], p_Pw1_ = args[6], 
p_Sh_ = args[7],
-      p_Sw_ = args[8], p_G_ = args[9];
+  int p_Ph0_ = args[3].cast<int>(), p_Pw0_ = args[4].cast<int>(), p_Ph1_ = 
args[5].cast<int>(),
+      p_Pw1_ = args[6].cast<int>(), p_Sh_ = args[7].cast<int>(), p_Sw_ = 
args[8].cast<int>(),
+      p_G_ = args[9].cast<int>();
   bool channel_last = args[10].cast<bool>();
   bool pre_cast = args[11].cast<bool>();
   bool post_cast = args[12].cast<bool>();
diff --git a/src/runtime/contrib/tflite/tflite_runtime.cc 
b/src/runtime/contrib/tflite/tflite_runtime.cc
index 82257e2a82..7d7f22690c 100644
--- a/src/runtime/contrib/tflite/tflite_runtime.cc
+++ b/src/runtime/contrib/tflite/tflite_runtime.cc
@@ -156,11 +156,12 @@ PackedFunc TFLiteRuntime::GetFunction(const String& name, 
const ObjectPtr<Object
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
       int in_idx = args[0].cast<int>();
       ICHECK_GE(in_idx, 0);
-      this->SetInput(in_idx, args[1]);
+      this->SetInput(in_idx, args[1].cast<DLTensor*>());
     });
   } else if (name == "get_output") {
-    return PackedFunc(
-        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = 
this->GetOutput(args[0]); });
+    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
+      *rv = this->GetOutput(args[0].cast<int>());
+    });
   } else if (name == "invoke") {
     return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { 
this->Invoke(); });
   } else if (name == "set_num_threads") {
@@ -181,7 +182,7 @@ Module TFLiteRuntimeCreate(const std::string& 
tflite_model_bytes, Device dev) {
 }
 
 TVM_REGISTER_GLOBAL("tvm.tflite_runtime.create").set_body_packed([](TVMArgs 
args, TVMRetValue* rv) {
-  *rv = TFLiteRuntimeCreate(args[0], args[1]);
+  *rv = TFLiteRuntimeCreate(args[0].get<std::string>(), args[1].get<Device>());
 });
 
 
TVM_REGISTER_GLOBAL("target.runtime.tflite").set_body_typed(TFLiteRuntimeCreate);

Reply via email to