This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2cbbcd5e2c [Refactor][Metal] Update ICHECK to TVM_FFI_ICHECK in Metal runtime (#18811)
2cbbcd5e2c is described below

commit 2cbbcd5e2ce774d338744fa84a5719abc6125bc2
Author: Bryan <[email protected]>
AuthorDate: Mon Feb 23 07:29:37 2026 -0500

    [Refactor][Metal] Update ICHECK to TVM_FFI_ICHECK in Metal runtime (#18811)
    
    This commit updates all ICHECK macros to TVM_FFI_ICHECK in the Metal
    runtime implementation to align with the new FFI refactoring.
    
    The changes replace the ICHECK, ICHECK_LT, and ICHECK_EQ macros with
    their TVM_FFI_ICHECK equivalents in the following files:
    - src/runtime/metal/metal_device_api.mm
    - src/runtime/metal/metal_module.mm
    
    This refactoring ensures consistency with the new TVM FFI interface and
    maintains the same error checking behavior while using the updated macro
    names.
---
 src/runtime/metal/metal_device_api.mm | 23 ++++++++++++-----------
 src/runtime/metal/metal_module.mm     | 20 ++++++++++----------
 2 files changed, 22 insertions(+), 21 deletions(-)

diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index c0218a5bf2..5ff9c2dfcd 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -48,7 +48,7 @@ void MetalWorkspace::GetAttr(Device dev, DeviceAttrKind kind, 
ffi::Any* rv) {
       *rv = int(index < devices.size());
       return;
     }
-    ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
+    TVM_FFI_ICHECK_LT(index, devices.size()) << "Invalid device id " << index;
     switch (kind) {
       case kMaxThreadsPerBlock: {
         *rv = static_cast<int>([devices[dev.device_id] 
maxThreadsPerThreadgroup].width);
@@ -125,11 +125,11 @@ int GetWarpSize(id<MTLDevice> dev) {
   id<MTLLibrary> lib = [dev newLibraryWithSource:[NSString 
stringWithUTF8String:kDummyKernel]
                                          options:nil
                                            error:&error_msg];
-  ICHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
+  TVM_FFI_ICHECK(lib != nil) << [[error_msg localizedDescription] UTF8String];
   id<MTLFunction> f = [lib newFunctionWithName:[NSString 
stringWithUTF8String:"CopyKernel"]];
-  ICHECK(f != nil);
+  TVM_FFI_ICHECK(f != nil);
   id<MTLComputePipelineState> state = [dev 
newComputePipelineStateWithFunction:f error:&error_msg];
-  ICHECK(state != nil) << [[error_msg localizedDescription] UTF8String];
+  TVM_FFI_ICHECK(state != nil) << [[error_msg localizedDescription] 
UTF8String];
   int size = static_cast<int>(state.threadExecutionWidth);
   [state release];
   [f release];
@@ -193,7 +193,7 @@ void* MetalWorkspace::AllocDataSpace(Device device, size_t 
nbytes, size_t alignm
     #endif
     */
     buf = [dev newBufferWithLength:nbytes options:storage_mode];
-    ICHECK(buf != nil);
+    TVM_FFI_ICHECK(buf != nil);
   };
   return (void*)(buf);
 }
@@ -214,8 +214,8 @@ void MetalWorkspace::FreeDataSpace(Device dev, void* ptr) {
 
 Stream* MetalWorkspace::CastStreamOrGetDefault(TVMStreamHandle stream, int 
device_id) {
   if (stream != nullptr) return static_cast<Stream*>(stream);
-  ICHECK_LT(static_cast<size_t>(device_id), default_streams_.size());
-  ICHECK(default_streams_[device_id] != nullptr);
+  TVM_FFI_ICHECK_LT(static_cast<size_t>(device_id), default_streams_.size());
+  TVM_FFI_ICHECK(default_streams_[device_id] != nullptr);
   return default_streams_[device_id];
 }
 
@@ -234,7 +234,8 @@ void MetalWorkspace::CopyDataFromTo(const void* from, 
size_t from_offset, void*
     int to_dev_type = static_cast<int>(dev_to.device_type);
 
     if (from_dev_type == kDLMetal && to_dev_type == kDLMetal) {
-      ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "Metal disallow cross 
device copy.";
+      TVM_FFI_ICHECK_EQ(dev_from.device_id, dev_to.device_id)
+          << "Metal disallow cross device copy.";
       id<MTLBlitCommandEncoder> encoder = [cb blitCommandEncoder];
       [encoder copyFromBuffer:(id<MTLBuffer>)(from)
                  sourceOffset:from_offset
@@ -287,14 +288,14 @@ void MetalWorkspace::CopyDataFromTo(const void* from, 
size_t from_offset, void*
 }
 
 TVMStreamHandle MetalWorkspace::CreateStream(Device dev) {
-  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << 
dev.device_id;
+  TVM_FFI_ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << 
dev.device_id;
   Stream* stream = new Stream(devices[dev.device_id]);
   return static_cast<TVMStreamHandle>(stream);
 }
 
 void MetalWorkspace::FreeStream(Device dev, TVMStreamHandle stream) {
-  ICHECK(stream != nullptr);
-  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << 
dev.device_id;
+  TVM_FFI_ICHECK(stream != nullptr);
+  TVM_FFI_ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << 
dev.device_id;
   delete static_cast<Stream*>(stream);
 }
 
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index cf1a1641be..deb863c69b 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -86,7 +86,7 @@ class MetalModuleNode final : public ffi::ModuleObj {
   // get a from primary context in device_id
   id<MTLComputePipelineState> GetPipelineState(size_t device_id, const 
std::string& func_name) {
     metal::MetalWorkspace* w = metal::MetalWorkspace::Global();
-    ICHECK_LT(device_id, w->devices.size());
+    TVM_FFI_ICHECK_LT(device_id, w->devices.size());
     // start lock scope.
     std::lock_guard<std::mutex> lock(mutex_);
     if (finfo_.size() <= device_id) {
@@ -100,7 +100,7 @@ class MetalModuleNode final : public ffi::ModuleObj {
     id<MTLLibrary> lib = nil;
     auto kernel = smap_.find(func_name);
     // Directly lookup kernels
-    ICHECK(kernel != smap_.end());
+    TVM_FFI_ICHECK(kernel != smap_.end());
     const std::string& source = kernel->second;
 
     if (fmt_ == "metal") {
@@ -132,18 +132,18 @@ class MetalModuleNode final : public ffi::ModuleObj {
       }
     }
     id<MTLFunction> f = [lib newFunctionWithName:[NSString 
stringWithUTF8String:func_name.c_str()]];
-    ICHECK(f != nil) << "cannot find function " << func_name;
+    TVM_FFI_ICHECK(f != nil) << "cannot find function " << func_name;
     id<MTLComputePipelineState> state =
         [w->devices[device_id] newComputePipelineStateWithFunction:f 
error:&err_msg];
-    ICHECK(state != nil) << "cannot get state:"
-                         << " for function " << func_name
-                         << [[err_msg localizedDescription] UTF8String];
+    TVM_FFI_ICHECK(state != nil) << "cannot get state:"
+                                 << " for function " << func_name
+                                 << [[err_msg localizedDescription] 
UTF8String];
     [f release];
     [lib release];
     // The state.threadExecutionWidth can change dynamically according
     // to the resource constraint in kernel, so it is not strictly hold
     // Turn of warp aware optimziation for now.
-    // ICHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
+    // TVM_FFI_ICHECK_EQ(state.threadExecutionWidth, w->warp_size[device_id]);
     if (e.smap[func_name] != nil) [e.smap[func_name] release];
     e.smap[func_name] = state;
     return state;
@@ -235,7 +235,7 @@ class MetalWrappedFunc {
       // attach error message with function name
       [cb addCompletedHandler:^(id<MTLCommandBuffer> buffer) {
         if (buffer.status == MTLCommandBufferStatusError) {
-          ICHECK(buffer.error != nil);
+          TVM_FFI_ICHECK(buffer.error != nil);
           std::ostringstream os;
           os << "GPUError happens after running " << func_name_ << ": "
              << buffer.error.localizedDescription.UTF8String;
@@ -270,7 +270,7 @@ ffi::Optional<ffi::Function> 
MetalModuleNode::GetFunction(const ffi::String& nam
   ffi::Function ret;
   AUTORELEASEPOOL {
     ObjectPtr<Object> sptr_to_self = ffi::GetObjectPtr<Object>(this);
-    ICHECK_EQ(sptr_to_self.get(), this);
+    TVM_FFI_ICHECK_EQ(sptr_to_self.get(), this);
     auto opt_info = fmap_.Get(name);
     if (!opt_info.has_value()) {
       return;
@@ -325,7 +325,7 @@ ffi::Module MetalModuleLoadFromBytes(const ffi::Bytes& 
bytes) {
 
   stream.Read(&ver);
   stream.Read(&smap);
-  ICHECK(stream.Read(&fmap));
+  TVM_FFI_ICHECK(stream.Read(&fmap));
   stream.Read(&fmt);
 
   return MetalModuleCreate(smap, fmap, fmt, "");

Reply via email to