This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 08b32a7976 [Runtime][ROCm] Enable ROCm host memory support (#17037) 08b32a7976 is described below commit 08b32a797642515b0b263ead292af6962fea0cf4 Author: Ruihang Lai <ruiha...@cs.cmu.edu> AuthorDate: Thu May 30 07:28:26 2024 -0400 [Runtime][ROCm] Enable ROCm host memory support (#17037) This PR enables the ROCMHost memory support in ROCm device API. --- src/runtime/ndarray.cc | 3 ++- src/runtime/rocm/rocm_device_api.cc | 40 ++++++++++++++++++++++++++++++++----- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index c2efa79c0c..c2cf5f388a 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -316,7 +316,8 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* to, TVMStreamHandle str ICHECK(from->device.device_type == to->device.device_type || from->device.device_type == kDLCPU || to->device.device_type == kDLCPU || from->device.device_type == kDLCUDAHost || - to->device.device_type == kDLCUDAHost) + to->device.device_type == kDLCUDAHost || from->device.device_type == kDLROCMHost || + to->device.device_type == kDLROCMHost) << "Can not copy across different device types directly. From device type: " << from->device.device_type << " to device type: " << to->device.device_type; diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc index f3cc46f927..e2a5048ca0 100644 --- a/src/runtime/rocm/rocm_device_api.cc +++ b/src/runtime/rocm/rocm_device_api.cc @@ -144,16 +144,26 @@ class ROCMDeviceAPI final : public DeviceAPI { *rv = value; } void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final { - ROCM_CALL(hipSetDevice(dev.device_id)); ICHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes"; void* ret; - ROCM_CALL(hipMalloc(&ret, nbytes)); + if (dev.device_type == kDLROCMHost) { + VLOG(1) << "allocating " << nbytes << "bytes on host"; + ROCM_CALL(hipHostMalloc(&ret, nbytes)); + } else { + ROCM_CALL(hipSetDevice(dev.device_id)); + VLOG(1) << "allocating " << nbytes << " bytes on device"; + ROCM_CALL(hipMalloc(&ret, nbytes)); + } return ret; } void FreeDataSpace(Device dev, void* ptr) final { - ROCM_CALL(hipSetDevice(dev.device_id)); - ROCM_CALL(hipFree(ptr)); + if (dev.device_type == kDLROCMHost) { + ROCM_CALL(hipHostFree(ptr)); + } else { + ROCM_CALL(hipSetDevice(dev.device_id)); + ROCM_CALL(hipFree(ptr)); + } } void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, @@ -162,6 +172,21 @@ class ROCMDeviceAPI final : public DeviceAPI { hipStream_t hip_stream = static_cast<hipStream_t>(stream); from = static_cast<const char*>(from) + from_offset; to = static_cast<char*>(to) + to_offset; + + if (dev_from.device_type == kDLROCMHost) { + dev_from.device_type = kDLCPU; + } + + if (dev_to.device_type == kDLROCMHost) { + dev_to.device_type = kDLCPU; + } + + // In case there is a copy from host mem to host mem */ + if (dev_to.device_type == kDLCPU && dev_from.device_type == kDLCPU) { + memcpy(to, from, size); + return; + } + if (dev_from.device_type == kDLROCM && dev_to.device_type == kDLROCM) { ROCM_CALL(hipSetDevice(dev_from.device_id)); if (dev_from.device_id == dev_to.device_id) { @@ -210,7 +235,7 @@ class ROCMDeviceAPI final : public DeviceAPI { private: static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind kind, hipStream_t stream) { - if (stream != 0) { + if (stream != nullptr) { ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream)); } else { ROCM_CALL(hipMemcpy(to, from, size, kind)); @@ -229,6 +254,11 @@ TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs args, TVMRetValue* rv *rv = static_cast<void*>(ptr); }); +TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body([](TVMArgs args, TVMRetValue* rv) { + DeviceAPI* ptr = ROCMDeviceAPI::Global(); + *rv = static_cast<void*>(ptr); +}); + class ROCMTimerNode : public TimerNode { public: virtual void Start() {