This is an automated email from the ASF dual-hosted git repository. junrushao pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 0a3e178 [metal] update language version (#7116) 0a3e178 is described below commit 0a3e1783b30910f4496e437b0ebefd998bf5a935 Author: Bing Xu <antinucl...@gmail.com> AuthorDate: Tue Dec 15 23:24:46 2020 -0800 [metal] update language version (#7116) * [metal] update language version * fix mps --- src/runtime/contrib/mps/conv.mm | 9 ++++++--- src/runtime/metal/metal_module.mm | 3 +-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/runtime/contrib/mps/conv.mm b/src/runtime/contrib/mps/conv.mm index 3b16f08..b860ee2 100644 --- a/src/runtime/contrib/mps/conv.mm +++ b/src/runtime/contrib/mps/conv.mm @@ -34,7 +34,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.buffer2img").set_body([](TVMArgs args, TVMR id<MTLDevice> dev = entry_ptr->metal_api->GetDevice(buf->ctx); id<MTLBuffer> temp = rt->GetTempBuffer(buf->ctx, [mtlbuf length]); entry_ptr->metal_api->CopyDataFromTo((__bridge void*)mtlbuf, 0, (__bridge void*)temp, 0, - [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); + [mtlbuf length], buf -> ctx, buf -> ctx, buf -> dtype, + nullptr); MPSImageDescriptor* desc = [MPSImageDescriptor imageDescriptorWithChannelFormat:MPSImageFeatureChannelFormatFloat32 @@ -69,7 +70,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.img2buffer").set_body([](TVMArgs args, TVMR imageIndex:0]; entry_ptr->metal_api->CopyDataFromTo((__bridge void*)temp, 0, (__bridge void*)mtlbuf, 0, - [mtlbuf length], buf -> ctx, buf -> ctx, nullptr); + [mtlbuf length], buf -> ctx, buf -> ctx, buf -> dtype, + nullptr); }); TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetValue* ret) { @@ -111,7 +113,8 @@ TVM_REGISTER_GLOBAL("tvm.contrib.mps.conv2d").set_body([](TVMArgs args, TVMRetVa id<MTLBuffer> bufB = (__bridge id<MTLBuffer>)(weight->data); id<MTLBuffer> tempB = rt->GetTempBuffer(weight->ctx, [bufB length]); entry_ptr->metal_api->CopyDataFromTo((__bridge void*)bufB, 0, (__bridge void*)tempB, 0, - [bufB length], weight -> ctx, weight -> ctx, nullptr); + [bufB length], weight -> ctx, weight -> ctx, tmp_in.dtype, + nullptr); float* ptr_w = (float*)[tempB contents]; // output to MPSImage DLTensor tmp_out; diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 7d46811..981dd61 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -88,8 +88,7 @@ class MetalModuleNode final : public runtime::ModuleNode { if (e.lib == nil) { if (fmt_ == "metal") { MTLCompileOptions* opts = [MTLCompileOptions alloc]; - // Use the Metal 1.2 for now. - opts.languageVersion = MTLLanguageVersion1_2; + opts.languageVersion = MTLLanguageVersion2_3; opts.fastMathEnabled = YES; // opts = nil; e.lib = [w->devices[device_id]