This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 48992a4093 [DeviceAPI] Support "GetCurrentStream" (#16689)
48992a4093 is described below
commit 48992a4093daf59c630cfa5d47271e27aeccccc8
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 9 08:11:56 2024 -0500
[DeviceAPI] Support "GetCurrentStream" (#16689)
This PR introduces a new function, `GetCurrentStream`, to the device API,
which returns the current stream of the given device.
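Meanwhile, this PR updates CUDA's `CreateStream` to create
a non-blocking stream, so that execution on this stream can
overlap with execution on other streams (a stream created with
`cudaStreamNonBlocking` performs no implicit synchronization with
the legacy default stream).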
This PR also changes `GPUCopy` in the CUDA device API to always
use `cudaMemcpyAsync`, even when the stream is null (previously it
fell back to the synchronous `cudaMemcpy`).
---
include/tvm/runtime/device_api.h | 6 ++++++
src/runtime/c_runtime_api.cc | 2 ++
src/runtime/cuda/cuda_device_api.cc | 12 ++++++------
src/runtime/metal/metal_common.h | 1 +
src/runtime/metal/metal_device_api.mm | 5 +++++
src/runtime/minrpc/rpc_reference.h | 1 +
src/runtime/rocm/rocm_device_api.cc | 4 ++++
src/runtime/rpc/rpc_device_api.cc | 7 ++++++-
src/runtime/rpc/rpc_endpoint.cc | 12 ++++++++++++
src/runtime/vulkan/vulkan_device_api.cc | 2 ++
src/runtime/vulkan/vulkan_device_api.h | 1 +
web/emcc/webgpu_runtime.cc | 2 ++
12 files changed, 48 insertions(+), 7 deletions(-)
diff --git a/include/tvm/runtime/device_api.h b/include/tvm/runtime/device_api.h
index 9ff469b7c8..721990c625 100644
--- a/include/tvm/runtime/device_api.h
+++ b/include/tvm/runtime/device_api.h
@@ -176,6 +176,12 @@ class TVM_DLL DeviceAPI {
* \param stream The stream to be set.
*/
virtual void SetStream(Device dev, TVMStreamHandle stream) {}
+ /*!
+ * \brief Get the current stream
+ * \param dev The device to get stream.
+ * \return The current stream of the device.
+ */
+ virtual TVMStreamHandle GetCurrentStream(Device dev);
/*!
* \brief Synchronize 2 streams of execution.
*
diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc
index 0881eaf704..799ef116ce 100644
--- a/src/runtime/c_runtime_api.cc
+++ b/src/runtime/c_runtime_api.cc
@@ -210,6 +210,8 @@ TVMStreamHandle DeviceAPI::CreateStream(Device dev) {
return nullptr; }
void DeviceAPI::FreeStream(Device dev, TVMStreamHandle stream) {}
+TVMStreamHandle DeviceAPI::GetCurrentStream(Device dev) { return nullptr; }
+
void DeviceAPI::SyncStreamFromTo(Device dev, TVMStreamHandle event_src,
TVMStreamHandle event_dst) {
}
diff --git a/src/runtime/cuda/cuda_device_api.cc b/src/runtime/cuda/cuda_device_api.cc
index dcc7276bbf..a599d95f33 100644
--- a/src/runtime/cuda/cuda_device_api.cc
+++ b/src/runtime/cuda/cuda_device_api.cc
@@ -195,7 +195,7 @@ class CUDADeviceAPI final : public DeviceAPI {
TVMStreamHandle CreateStream(Device dev) {
CUDA_CALL(cudaSetDevice(dev.device_id));
cudaStream_t retval;
- CUDA_CALL(cudaStreamCreate(&retval));
+ CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
return static_cast<TVMStreamHandle>(retval);
}
@@ -225,6 +225,10 @@ class CUDADeviceAPI final : public DeviceAPI {
CUDAThreadEntry::ThreadLocal()->stream = static_cast<cudaStream_t>(stream);
}
+ TVMStreamHandle GetCurrentStream(Device dev) final {
+ return
static_cast<TVMStreamHandle>(CUDAThreadEntry::ThreadLocal()->stream);
+ }
+
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
@@ -243,11 +247,7 @@ class CUDADeviceAPI final : public DeviceAPI {
private:
static void GPUCopy(const void* from, void* to, size_t size, cudaMemcpyKind
kind,
cudaStream_t stream) {
- if (stream != nullptr) {
- CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
- } else {
- CUDA_CALL(cudaMemcpy(to, from, size, kind));
- }
+ CUDA_CALL(cudaMemcpyAsync(to, from, size, kind, stream));
}
};
diff --git a/src/runtime/metal/metal_common.h b/src/runtime/metal/metal_common.h
index d9154e0f79..dc7b344800 100644
--- a/src/runtime/metal/metal_common.h
+++ b/src/runtime/metal/metal_common.h
@@ -155,6 +155,7 @@ class MetalWorkspace final : public DeviceAPI {
void FreeStream(Device dev, TVMStreamHandle stream) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
void SetStream(Device dev, TVMStreamHandle stream) final;
+ TVMStreamHandle GetCurrentStream(Device dev) final;
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final;
void FreeWorkspace(Device dev, void* data) final;
void ReinitializeDefaultStreams();
diff --git a/src/runtime/metal/metal_device_api.mm b/src/runtime/metal/metal_device_api.mm
index e3853ef6d6..3b01bc65b1 100644
--- a/src/runtime/metal/metal_device_api.mm
+++ b/src/runtime/metal/metal_device_api.mm
@@ -312,6 +312,11 @@ void MetalWorkspace::SetStream(Device dev, TVMStreamHandle stream) {
MetalThreadEntry::ThreadLocal()->stream[dev.device_id] = stream;
}
+TVMStreamHandle MetalWorkspace::GetCurrentStream(Device dev) {
+ ICHECK_LT(dev.device_id, devices.size()) << "Invalid device id " <<
dev.device_id;
+ return MetalThreadEntry::ThreadLocal()->stream[dev.device_id];
+}
+
void* MetalWorkspace::AllocWorkspace(Device dev, size_t size, DLDataType
type_hint) {
return MetalThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h
index 732b017e44..d08dadb02b 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -69,6 +69,7 @@ enum class RPCCode : int {
kDevCreateStream,
kDevFreeStream,
kDevSetStream,
+ kDevGetCurrentStream,
};
/*!
diff --git a/src/runtime/rocm/rocm_device_api.cc b/src/runtime/rocm/rocm_device_api.cc
index 50dede05a9..ffc8d5a805 100644
--- a/src/runtime/rocm/rocm_device_api.cc
+++ b/src/runtime/rocm/rocm_device_api.cc
@@ -186,6 +186,10 @@ class ROCMDeviceAPI final : public DeviceAPI {
ROCMThreadEntry::ThreadLocal()->stream = static_cast<hipStream_t>(stream);
}
+ TVMStreamHandle GetCurrentStream(Device dev) final {
+ return
static_cast<TVMStreamHandle>(ROCMThreadEntry::ThreadLocal()->stream);
+ }
+
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
return ROCMThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}
diff --git a/src/runtime/rpc/rpc_device_api.cc b/src/runtime/rpc/rpc_device_api.cc
index a2d1ac17ef..a5c8541dc0 100644
--- a/src/runtime/rpc/rpc_device_api.cc
+++ b/src/runtime/rpc/rpc_device_api.cc
@@ -126,11 +126,16 @@ class RPCDeviceAPI final : public DeviceAPI {
GetSess(dev)->GetDeviceAPI(remote_dev)->StreamSync(remote_dev, stream);
}
- void SetStream(Device dev, TVMStreamHandle stream) {
+ void SetStream(Device dev, TVMStreamHandle stream) final {
auto remote_dev = RemoveRPCSessionMask(dev);
GetSess(dev)->GetDeviceAPI(remote_dev)->SetStream(remote_dev, stream);
}
+ TVMStreamHandle GetCurrentStream(Device dev) final {
+ auto remote_dev = RemoveRPCSessionMask(dev);
+ return
GetSess(dev)->GetDeviceAPI(remote_dev)->GetCurrentStream(remote_dev);
+ }
+
protected:
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t
to_offset,
size_t num_bytes, Device dev_from, Device dev_to,
DLDataType type_hint,
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index a0c732a9c8..b4f455cc18 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -1006,6 +1006,11 @@ void RPCDevSetStream(RPCSession* handler, TVMArgs args, TVMRetValue* rv) {
handler->GetDeviceAPI(dev)->SetStream(dev, stream);
}
+void RPCDevGetCurrentStream(RPCSession* handler, TVMArgs args, TVMRetValue*
rv) {
+ Device dev = args[0];
+ *rv = handler->GetDeviceAPI(dev)->GetCurrentStream(dev);
+}
+
void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
// Event handler sit at clean state at this point.
switch (code) {
@@ -1043,6 +1048,9 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
case RPCCode::kDevSetStream:
SysCallHandler(RPCDevSetStream);
break;
+ case RPCCode::kDevGetCurrentStream:
+ SysCallHandler(RPCDevGetCurrentStream);
+ break;
case RPCCode::kCopyAmongRemote:
SysCallHandler(RPCCopyAmongRemote);
break;
@@ -1188,6 +1196,10 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
endpoint_->SysCallRemote(RPCCode::kDevSetStream, dev, stream);
}
+ TVMStreamHandle GetCurrentStream(Device dev) final {
+ return endpoint_->SysCallRemote(RPCCode::kDevGetCurrentStream, dev);
+ }
+
DeviceAPI* GetDeviceAPI(Device dev, bool allow_missing) final { return this;
}
bool IsLocalSession() const final { return false; }
diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc
index e02c9304e1..18a40bf54f 100644
--- a/src/runtime/vulkan/vulkan_device_api.cc
+++ b/src/runtime/vulkan/vulkan_device_api.cc
@@ -327,6 +327,8 @@ void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) {
ICHECK_EQ(stream, static_cast<void*>(nullptr));
}
+TVMStreamHandle VulkanDeviceAPI::GetCurrentStream(Device dev) { return
nullptr; }
+
void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset,
void* to,
size_t to_offset, size_t size, Device
dev_from, Device dev_to,
DLDataType type_hint, TVMStreamHandle
stream) {
diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h
index 851fede306..35100ee627 100644
--- a/src/runtime/vulkan/vulkan_device_api.h
+++ b/src/runtime/vulkan/vulkan_device_api.h
@@ -62,6 +62,7 @@ class VulkanDeviceAPI final : public DeviceAPI {
void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle
event_dst) final;
void StreamSync(Device dev, TVMStreamHandle stream) final;
void SetStream(Device dev, TVMStreamHandle stream) final;
+ TVMStreamHandle GetCurrentStream(Device dev) final;
protected:
void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t
to_offset, size_t size,
diff --git a/web/emcc/webgpu_runtime.cc b/web/emcc/webgpu_runtime.cc
index 957c8752ff..ce2a7cadb6 100644
--- a/web/emcc/webgpu_runtime.cc
+++ b/web/emcc/webgpu_runtime.cc
@@ -116,6 +116,8 @@ class WebGPUDeviceAPI : public DeviceAPI {
void SetStream(Device dev, TVMStreamHandle stream) final { LOG(FATAL) <<
"Not implemented"; }
+ TVMStreamHandle GetCurrentStream(Device dev) final { LOG(FATAL) << "Not
implemented"; }
+
void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final {
return WebGPUThreadEntry::ThreadLocal()->pool.AllocWorkspace(dev, size);
}