This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 08b32a7976 [Runtime][ROCm] Enable ROCm host memory support (#17037)
08b32a7976 is described below

commit 08b32a797642515b0b263ead292af6962fea0cf4
Author: Ruihang Lai <ruiha...@cs.cmu.edu>
AuthorDate: Thu May 30 07:28:26 2024 -0400

    [Runtime][ROCm] Enable ROCm host memory support (#17037)
    
    This PR enables ROCMHost memory support in the ROCm device API.
---
 src/runtime/ndarray.cc              |  3 ++-
 src/runtime/rocm/rocm_device_api.cc | 40 ++++++++++++++++++++++++++++++++-----
 2 files changed, 37 insertions(+), 6 deletions(-)

diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index c2efa79c0c..c2cf5f388a 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -316,7 +316,8 @@ void NDArray::CopyFromTo(const DLTensor* from, DLTensor* 
to, TVMStreamHandle str
 
   ICHECK(from->device.device_type == to->device.device_type || 
from->device.device_type == kDLCPU ||
          to->device.device_type == kDLCPU || from->device.device_type == 
kDLCUDAHost ||
-         to->device.device_type == kDLCUDAHost)
+         to->device.device_type == kDLCUDAHost || from->device.device_type == 
kDLROCMHost ||
+         to->device.device_type == kDLROCMHost)
       << "Can not copy across different device types directly. From device 
type: "
       << from->device.device_type << " to device type: " << 
to->device.device_type;
 
diff --git a/src/runtime/rocm/rocm_device_api.cc 
b/src/runtime/rocm/rocm_device_api.cc
index f3cc46f927..e2a5048ca0 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -144,16 +144,26 @@ class ROCMDeviceAPI final : public DeviceAPI {
     *rv = value;
   }
   void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType 
type_hint) final {
-    ROCM_CALL(hipSetDevice(dev.device_id));
     ICHECK_EQ(256 % alignment, 0U) << "ROCM space is aligned at 256 bytes";
     void* ret;
-    ROCM_CALL(hipMalloc(&ret, nbytes));
+    if (dev.device_type == kDLROCMHost) {
+      VLOG(1) << "allocating " << nbytes << "bytes on host";
+      ROCM_CALL(hipHostMalloc(&ret, nbytes));
+    } else {
+      ROCM_CALL(hipSetDevice(dev.device_id));
+      VLOG(1) << "allocating " << nbytes << " bytes on device";
+      ROCM_CALL(hipMalloc(&ret, nbytes));
+    }
     return ret;
   }
 
   void FreeDataSpace(Device dev, void* ptr) final {
-    ROCM_CALL(hipSetDevice(dev.device_id));
-    ROCM_CALL(hipFree(ptr));
+    if (dev.device_type == kDLROCMHost) {
+      ROCM_CALL(hipHostFree(ptr));
+    } else {
+      ROCM_CALL(hipSetDevice(dev.device_id));
+      ROCM_CALL(hipFree(ptr));
+    }
   }
 
   void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t 
to_offset, size_t size,
@@ -162,6 +172,21 @@ class ROCMDeviceAPI final : public DeviceAPI {
     hipStream_t hip_stream = static_cast<hipStream_t>(stream);
     from = static_cast<const char*>(from) + from_offset;
     to = static_cast<char*>(to) + to_offset;
+
+    if (dev_from.device_type == kDLROCMHost) {
+      dev_from.device_type = kDLCPU;
+    }
+
+    if (dev_to.device_type == kDLROCMHost) {
+      dev_to.device_type = kDLCPU;
+    }
+
+    // In case there is a copy from host mem to host mem
+    if (dev_to.device_type == kDLCPU && dev_from.device_type == kDLCPU) {
+      memcpy(to, from, size);
+      return;
+    }
+
     if (dev_from.device_type == kDLROCM && dev_to.device_type == kDLROCM) {
       ROCM_CALL(hipSetDevice(dev_from.device_id));
       if (dev_from.device_id == dev_to.device_id) {
@@ -210,7 +235,7 @@ class ROCMDeviceAPI final : public DeviceAPI {
  private:
   static void GPUCopy(const void* from, void* to, size_t size, hipMemcpyKind 
kind,
                       hipStream_t stream) {
-    if (stream != 0) {
+    if (stream != nullptr) {
       ROCM_CALL(hipMemcpyAsync(to, from, size, kind, stream));
     } else {
       ROCM_CALL(hipMemcpy(to, from, size, kind));
@@ -229,6 +254,11 @@ TVM_REGISTER_GLOBAL("device_api.rocm").set_body([](TVMArgs 
args, TVMRetValue* rv
   *rv = static_cast<void*>(ptr);
 });
 
+TVM_REGISTER_GLOBAL("device_api.rocm_host").set_body([](TVMArgs args, 
TVMRetValue* rv) {
+  DeviceAPI* ptr = ROCMDeviceAPI::Global();
+  *rv = static_cast<void*>(ptr);
+});
+
 class ROCMTimerNode : public TimerNode {
  public:
   virtual void Start() {

Reply via email to