This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch vk-i64
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ef1ed2d732000ba2c7bf68bd6a1342ca3c98e0df
Author: Masahiro Masuda <masahi...@gmail.com>
AuthorDate: Wed Mar 3 08:23:16 2021 +0900

    update metal runtime to use ArgUnion64 (not tested)
---
 src/runtime/metal/metal_module.mm | 4 ++--
 src/runtime/pack_args.h           | 4 +++-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm
index 981dd61..8f1fde8 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -180,7 +180,7 @@ class MetalWrappedFunc {
     scache_[dev_id] = m->GetPipelineState(dev_id, func_name);
   }
   // invoke the function with void arguments
-  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const {
+  void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const {
     metal::MetalThreadEntry* t = metal::MetalThreadEntry::ThreadLocal();
     int device_id = t->context.device_id;
     if (scache_[device_id] == nil) {
@@ -197,7 +197,7 @@ class MetalWrappedFunc {
     }
     if (num_pack_args_ != 0) {
       [encoder setBytes:pack_args
-                 length:num_pack_args_ * sizeof(ArgUnion)
+                 length:num_pack_args_ * sizeof(ArgUnion64)
                 atIndex:num_buffer_args_];
     }
     // launch
diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h
index 54a75d6..2e7a881 100644
--- a/src/runtime/pack_args.h
+++ b/src/runtime/pack_args.h
@@ -40,7 +40,6 @@ namespace tvm {
 namespace runtime {
 /*!
  * \brief argument union type of 32bit.
- * Choose 32 bit because most GPU API do not work well with 64 bit.
  */
 union ArgUnion {
   int32_t v_int32;
@@ -48,6 +47,9 @@ union ArgUnion {
   float v_float32;
 };
 
+/*!
+ * \brief argument union type of 64 bit, for use by Vulkan and Metal runtime.
+ */
 union ArgUnion64 {
   int32_t v_int32[2];
   uint32_t v_uint32[2];

Reply via email to