This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a3e178  [metal] update language version (#7116)
0a3e178 is described below

commit 0a3e1783b30910f4496e437b0ebefd998bf5a935
Author: Bing Xu <antinucl...@gmail.com>
AuthorDate: Tue Dec 15 23:24:46 2020 -0800

    [metal] update language version (#7116)
    
    * [metal] update language version
    
    * fix mps
---
 src/runtime/contrib/mps/conv.mm   | 9 ++++++---
 src/runtime/metal/metal_module.mm | 3 +--
 2 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm
index 3b16f08..b860ee2 100644
--- a/src/runtime/contrib/mps/conv.mm
+++ b/src/runtime/contrib/mps/conv.mm
@@ -34,7 +34,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img").set_body([](TVMArgs args, TVMR
   id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx);
   id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]);
   entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0,
-                                       [mtlbuf length], buf -> ctx, buf -> ctx, nullptr);
+                                       [mtlbuf length], buf -> ctx, buf -> ctx, buf -> dtype,
+                                       nullptr);
 
   MPSImageDescriptor* desc =
       [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32
@@ -69,7 +70,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer").set_body([](TVMArgs args, TVMR
          imageIndex:0];
 
   entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0,
-                                       [mtlbuf length], buf -> ctx, buf -> ctx, nullptr);
+                                       [mtlbuf length], buf -> ctx, buf -> ctx, buf -> dtype,
+                                       nullptr);
 });
 
 TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) {
@@ -111,7 +113,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetVa
   id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data);
   id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]);
   entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0,
-                                       [bufB length], weight -> ctx, weight -> ctx, nullptr);
+                                       [bufB length], weight -> ctx, weight -> ctx, tmp_in.dtype,
+                                       nullptr);
   float* ptr_w = (float*)[tempB contents];
   // output to MPSImage
   DLTensor tmp_out;
diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 7d46811..981dd61 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -88,8 +88,7 @@ class MetalModuleNode final : public runtime::ModuleNode {
     if (e.lib == nil) {
       if (fmt_ == "metal") {
         MTLCompileOptions* opts = [MTLCompileOptions alloc];
-        // Use the Metal 1.2 for now.
-        opts.languageVersion = MTLLanguageVersion1_2;
+        opts.languageVersion = MTLLanguageVersion2_3;
         opts.fastMathEnabled = YES;
         // opts = nil;
         e.lib = [w->devices[device_id]

Reply via email to