This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 48992a4093 [DeviceAPI] Support "GetCurrentStream" (#16689)
48992a4093 is described below

commit 48992a4093daf59c630cfa5d47271e27aeccccc8
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 9 08:11:56 2024 -0500

    [DeviceAPI] Support "GetCurrentStream" (#16689)
    
    This PR introduces a new function `GetCurrentStream` to the device API,
    which returns the current stream of the given device.
    
    Meanwhile, this PR updates the `CreateStream` of CUDA to create
    a non-blocking stream, so that the execution on this stream can
    overlap with the execution of other streams.
    
    This PR also changes the `GPUCopy` of the CUDA device API to always
    use `cudaMemcpyAsync`.
---
 include/tvm/runtime/device_api.h        |  6 ++++++
 src/runtime/c_runtime_api.cc            |  2 ++
 src/runtime/cuda/cuda_device_api.cc     | 12 ++++++------
 src/runtime/metal/metal_common.h        |  1 +
 src/runtime/metal/metal_device_api.mm   |  5 +++++
 src/runtime/minrpc/rpc_reference.h      |  1 +
 src/runtime/rocm/rocm_device_api.cc     |  4 ++++
 src/runtime/rpc/rpc_device_api.cc       |  7 ++++++-
 src/runtime/rpc/rpc_endpoint.cc         | 12 ++++++++++++
 src/runtime/vulkan/vulkan_device_api.cc |  2 ++
 src/runtime/vulkan/vulkan_device_api.h  |  1 +
 web/emcc/webgpu_runtime.cc              |  2 ++
 12 files changed, 48 insertions(+), 7 deletions(-)

diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 9ff469b7c8..721990c625 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -176,6 +176,12 @@ class TVM_DLL DeviceAPI {
    * \param stream The stream to be set.
    */
   virtual void SetStream(Device dev, TVMStreamHandle stream) {}
+  /*!
+   * \brief Get the current stream
+   * \param dev The device to get stream.
+   * \return The current stream of the device.
+   */
+  virtual TVMStreamHandle GetCurrentStream(Device dev);
   /*!
    * \brief Synchronize 2 streams of execution.
    *
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index 0881eaf704..799ef116ce 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -210,6 +210,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) { 
return nullptr; }
 
 void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}
 
+TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; }
+
 void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src, 
TVMStreamHandle event_dst) {
 }
 
diff --git a/src/runtime/cuda/cuda_device_api.cc 
b/src/runtime/cuda/cuda_device_api.cc
index dcc7276bbf..a599d95f33 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -195,7 +195,7 @@ class CUDADeviceAPI final : public DeviceAPI {
   TVMStreamHandle CreateStream(Device dev) {
     CUDA_CALL(cudaSetDevice(dev.device_id));
     cudaStream_t retval;
-    CUDA_CALL(cudaStreamCreate(&retval));
+    CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
     return static_cast<TVMStreamHandle>(retval);
   }
 
@@ -225,6 +225,10 @@ class CUDADeviceAPI final : public DeviceAPI {
     CUDAThreadEntry::ThreadLocal()->stream = static_cast<cudaStream_t>(stream);
   }
 
+  TVMStreamHandle GetCurrentStream(Device dev) final {
+    return 
static_cast<TVMStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
+  }
+
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
     return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
   }
@@ -243,11 +247,7 @@ class CUDADeviceAPI final : public DeviceAPI {
  private:
   static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind 
kind,
                       cudaStream_t stream) {
-    if (stream != nullptr) {
-      CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
-    } else {
-      CUDA_CALL(cudaMemcpy(to, from, size, kind));
-    }
+    CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
   }
 };
 
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index d9154e0f79..dc7b344800 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -155,6 +155,7 @@ class MetalWorkspace final : public DeviceAPI {
   void FreeStream(Device dev, TVMStreamHandle stream) final;
   void StreamSync(Device dev, TVMStreamHandle stream) final;
   void SetStream(Device dev, TVMStreamHandle stream) final;
+  TVMStreamHandle GetCurrentStream(Device dev) final;
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
   void FreeWorkspace(Device dev, void* data) final;
   void ReinitializeDefaultStreams();
diff --git a/src/runtime/metal/metal_device_api.mm 
b/src/runtime/metal/metal_device_api.mm
index e3853ef6d6..3b01bc65b1 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -312,6 +312,11 @@ void MetalWorkspace::SetStream(Device dev, TVMStreamHandle 
stream) {
   MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream;
 }
 
+TVMStreamHandle MetalWorkspace::GetCurrentStream(Device dev) {
+  ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " << 
dev.device_id;
+  return MetalThreadEntry::ThreadLocal()->stream[dev.device_id];
+}
+
 void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType 
type_hint) {
   return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
 }
diff --git a/src/runtime/minrpc/rpc_reference.h 
b/src/runtime/minrpc/rpc_reference.h
index 732b017e44..d08dadb02b 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -69,6 +69,7 @@ enum class RPCCode : int {
   kDevCreateStream,
   kDevFreeStream,
   kDevSetStream,
+  kDevGetCurrentStream,
 };
 
 /*!
diff --git a/src/runtime/rocm/rocm_device_api.cc 
b/src/runtime/rocm/rocm_device_api.cc
index 50dede05a9..ffc8d5a805 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -186,6 +186,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
     ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
   }
 
+  TVMStreamHandle GetCurrentStream(Device dev) final {
+    return 
static_cast<TVMStreamHandle>(ROCMThreadEntry::ThreadLocal()->stream);
+  }
+
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
     return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
   }
diff --git a/src/runtime/rpc/rpc_device_api.cc 
b/src/runtime/rpc/rpc_device_api.cc
index a2d1ac17ef..a5c8541dc0 100644
--- a/src/runtime/rpc/rpc_device_api.cc
+++ b/src/runtime/rpc/rpc_device_api.cc
@@ -126,11 +126,16 @@ class RPCDeviceAPI final : public DeviceAPI {
     GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream);
   }
 
-  void SetStream(Device dev, TVMStreamHandle stream) {
+  void SetStream(Device dev, TVMStreamHandle stream) final {
     auto remote_dev = RemoveRPCSessionMask(dev);
     GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream);
   }
 
+  TVMStreamHandle GetCurrentStream(Device dev) final {
+    auto remote_dev = RemoveRPCSessionMask(dev);
+    return 
GetSess(dev)->GetDeviceAPI(remote_dev)->GetCurrentStream(remote_dev);
+  }
+
  protected:
   void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t 
to_offset,
                       size_t num_bytes, Device dev_from, Device dev_to, 
DLDataType type_hint,
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index a0c732a9c8..b4f455cc18 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -1006,6 +1006,11 @@ void RPCDevSetStream(RPCSession* handler, TVMArgs args, 
TVMRetValue* rv) {
   handler->GetDeviceAPI(dev)->SetStream(dev, stream);
 }
 
+void RPCDevGetCurrentStream(RPCSession* handler, TVMArgs args, TVMRetValue* 
rv) {
+  Device dev = args[0];
+  *rv = handler->GetDeviceAPI(dev)->GetCurrentStream(dev);
+}
+
 void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
   // Event handler sit at clean state at this point.
   switch (code) {
@@ -1043,6 +1048,9 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode 
code) {
     case RPCCode::kDevSetStream:
       SysCallHandler(RPCDevSetStream);
       break;
+    case RPCCode::kDevGetCurrentStream:
+      SysCallHandler(RPCDevGetCurrentStream);
+      break;
     case RPCCode::kCopyAmongRemote:
       SysCallHandler(RPCCopyAmongRemote);
       break;
@@ -1188,6 +1196,10 @@ class RPCClientSession : public RPCSession, public 
DeviceAPI {
     endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream);
   }
 
+  TVMStreamHandle GetCurrentStream(Device dev) final {
+    return endpoint_->SysCallRemote(RPCCode::kDevGetCurrentStream, dev);
+  }
+
   DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this; 
}
 
   bool IsLocalSession() const final { return false; }
diff --git a/src/runtime/vulkan/vulkan_device_api.cc 
b/src/runtime/vulkan/vulkan_device_api.cc
index e02c9304e1..18a40bf54f 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -327,6 +327,8 @@ void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle 
stream) {
   ICHECK_EQ(stream, static_cast<void*>(nullptr));
 }
 
+TVMStreamHandle VulkanDeviceAPI::GetCurrentStream(Device dev) { return 
nullptr; }
+
 void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, 
void* to,
                                      size_t to_offset, size_t size, Device 
dev_from, Device dev_to,
                                      DLDataType type_hint, TVMStreamHandle 
stream) {
diff --git a/src/runtime/vulkan/vulkan_device_api.h 
b/src/runtime/vulkan/vulkan_device_api.h
index 851fede306..35100ee627 100644
--- a/src/runtime/vulkan/vulkan_device_api.h
+++ b/src/runtime/vulkan/vulkan_device_api.h
@@ -62,6 +62,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
   void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle 
event_dst) final;
   void StreamSync(Device dev, TVMStreamHandle stream) final;
   void SetStream(Device dev, TVMStreamHandle stream) final;
+  TVMStreamHandle GetCurrentStream(Device dev) final;
 
  protected:
   void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t 
to_offset, size_t size,
diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc
index 957c8752ff..ce2a7cadb6 100644
--- a/web/emcc/webgpu_runtime.cc
+++ b/web/emcc/webgpu_runtime.cc
@@ -116,6 +116,8 @@ class WebGPUDeviceAPI : public DeviceAPI {
 
   void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) << 
"Not implemented"; }
 
+  TVMStreamHandle GetCurrentStream(Device dev) final { LOG(FATAL) << "Not 
implemented"; }
+
   void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
     return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
   }

Reply via email to