This is an automated email from the ASF dual-hosted git repository. wuwei pushed a commit to branch vk-i64 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit ad075025a61bed23df0bf5f6c78853b5204725f4 Author: Masahiro Masuda <masahi...@gmail.com> AuthorDate: Wed Mar 3 07:37:52 2021 +0900 introduce ArgUnion64 --- src/runtime/pack_args.h | 26 +++++++++++++++++++------- src/runtime/vulkan/vulkan.cc | 12 ++++++------ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/src/runtime/pack_args.h b/src/runtime/pack_args.h index 45cde22..54a75d6 100644 --- a/src/runtime/pack_args.h +++ b/src/runtime/pack_args.h @@ -47,6 +47,15 @@ union ArgUnion { uint32_t v_uint32; float v_float32; }; + +union ArgUnion64 { + int32_t v_int32[2]; + uint32_t v_uint32[2]; + float v_float32[2]; + int64_t v_int64; + uint64_t v_uint64; + double v_float64; +}; /*! * \brief Create a packed function from void addr types. * @@ -177,25 +186,28 @@ template <int N, typename F> inline PackedFunc PackFuncNonBufferArg_(F f, int base, const std::vector<ArgConvertCode>& codes) { int num_args = static_cast<int>(codes.size()); auto ret = [f, codes, base, num_args](TVMArgs args, TVMRetValue* ret) { - TempArray<ArgUnion, N> holder_(num_args); - ArgUnion* holder = holder_.data(); + TempArray<ArgUnion64, N> holder_(num_args); + ArgUnion64* holder = holder_.data(); for (int i = 0; i < num_args; ++i) { switch (codes[i]) { - case INT64_TO_INT64: + case INT64_TO_INT64: { + holder[i].v_int64 = args.values[base + i].v_int64; + break; + } case FLOAT64_TO_FLOAT64: { - LOG(FATAL) << "Do not support 64bit argument to device function"; + holder[i].v_float64 = args.values[base + i].v_float64; break; } case INT64_TO_INT32: { - holder[i].v_int32 = static_cast<int32_t>(args.values[base + i].v_int64); + holder[i].v_int32[0] = static_cast<int32_t>(args.values[base + i].v_int64); break; } case INT64_TO_UINT32: { - holder[i].v_uint32 = static_cast<uint32_t>(args.values[base + i].v_int64); + holder[i].v_uint32[0] = static_cast<uint32_t>(args.values[base + i].v_int64); break; } case FLOAT64_TO_FLOAT32: { - holder[i].v_float32 = static_cast<float>(args.values[base + i].v_float64); + holder[i].v_float32[0] = static_cast<float>(args.values[base + i].v_float64); break; } case HANDLE_TO_HANDLE: { diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc index f40fd80..4eb3481 100644 --- a/src/runtime/vulkan/vulkan.cc +++ b/src/runtime/vulkan/vulkan.cc @@ -711,7 +711,7 @@ class VulkanWrappedFunc { thread_axis_cfg_.Init(num_buffer_args + num_pack_args, thread_axis_tags); } - void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const; + void operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const; private: // internal module @@ -875,7 +875,7 @@ class VulkanModuleNode final : public runtime::ModuleNode { VkPushConstantRange crange; crange.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT; crange.offset = 0; - crange.size = sizeof(ArgUnion) * num_pack_args; + crange.size = sizeof(ArgUnion64) * num_pack_args; VkPipelineLayoutCreateInfo playout_cinfo; playout_cinfo.sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO; @@ -1046,7 +1046,7 @@ VulkanStream* VulkanThreadEntry::Stream(size_t device_id) { return streams_[device_id].get(); } -void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion* pack_args) const { +void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion64* pack_args) const { int device_id = VulkanThreadEntry::ThreadLocal()->ctx.device_id; ICHECK_LT(device_id, kVulkanMaxNumDevice); const auto& vctx = VulkanDeviceAPI::Global()->context(device_id); @@ -1075,7 +1075,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion descriptor_buffers.data()); if (num_pack_args_ != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, - VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion), + VK_SHADER_STAGE_COMPUTE_BIT, 0, num_pack_args_ * sizeof(ArgUnion64), pack_args); } vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); @@ -1093,7 +1093,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion } // Otherwise, the more expensive deferred path. - std::vector<ArgUnion> pack_args_storage(pack_args, pack_args + num_pack_args_); + std::vector<ArgUnion64> pack_args_storage(pack_args, pack_args + num_pack_args_); const auto& deferred_initializer = [&vctx, pipeline, descriptor_buffers]() { std::vector<VkWriteDescriptorSet> write_descriptor_sets; write_descriptor_sets.resize(descriptor_buffers.size()); @@ -1119,7 +1119,7 @@ void VulkanWrappedFunc::operator()(TVMArgs args, TVMRetValue* rv, const ArgUnion nullptr); if (pack_args_storage.size() != 0) { vkCmdPushConstants(state->cmd_buffer_, pipeline->pipeline_layout, VK_SHADER_STAGE_COMPUTE_BIT, - 0, pack_args_storage.size() * sizeof(ArgUnion), pack_args_storage.data()); + 0, pack_args_storage.size() * sizeof(ArgUnion64), pack_args_storage.data()); } vkCmdDispatch(state->cmd_buffer_, wl.grid_dim(0), wl.grid_dim(1), wl.grid_dim(2)); VkMemoryBarrier barrier_info;