This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch vk-i64 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ef1ed2d732000ba2c7bf68bd6a1342ca3c98e0df
Author: Masahiro Masuda <masahi...@gmail.com>
AuthorDate: Wed Mar 3 08:23:16 2021 +0900

    update metal runtime to use ArgUnion64 (not tested)
---
 src/runtime/metal/metal_module.mm | 4 ++--
 src/runtime/pack_args.h           | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 981dd61..8f1fde8 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -180,7 +180,7 @@ class MetalWrappedFunc {
     scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
   }
   // invoke the function with void arguments
-  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
     metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
     int device_id = t->context.device_id;
     if (scache_[device_id] == nil) {
@@ -197,7 +197,7 @@ class MetalWrappedFunc {
     }
     if (num_pack_args_ != 0) {
       [encoder setBytes:pack_args
-                 length:num_pack_args_ * sizeof(ArgUnion)
+                 length:num_pack_args_ * sizeof(ArgUnion64)
                 atIndex:num_buffer_args_];
     }
     // launch
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 54a75d6..2e7a881 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -40,7 +40,6 @@ namespace tvm {
 namespace runtime {
 /*!
  * \brief argument union type of 32bit.
- * Choose 32 bit because most GPU API do not work well with 64 bit.
  */
 union ArgUnion {
   int32_t v_int32;
@@ -48,6 +47,9 @@ union ArgUnion {
   float v_float32;
 };
 
+/*!
+ * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime.
+ */
 union ArgUnion64 {
   int32_t v_int32[2];
   uint32_t v_uint32[2];