This is an automated email from the ASF dual-hosted git repository. junrushao pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:

     new 7c35267756 [Fix] add TVM_DLL to disco functions (#16258)

7c35267756 is described below

commit 7c352677568df0f12c49a4b5b8864b11fb37701f
Author: Lesheng Jin <34279105+lesheng...@users.noreply.github.com>
AuthorDate: Mon Dec 18 15:32:52 2023 +0800

    [Fix] add TVM_DLL to disco functions (#16258)
---
 include/tvm/runtime/disco/builtin.h                  |  4 ++--
 include/tvm/runtime/disco/disco_worker.h             |  2 +-
 include/tvm/runtime/relax_vm/ndarray_cache_support.h | 10 +++++-----
 src/runtime/disco/builtin.cc                         |  4 ++--
 src/runtime/disco/disco_worker.cc                    |  2 +-
 src/runtime/relax_vm/ndarray_cache_support.cc        | 11 ++++++-----
 6 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/include/tvm/runtime/disco/builtin.h b/include/tvm/runtime/disco/builtin.h
index 3847aef3f2..512059b31b 100644
--- a/include/tvm/runtime/disco/builtin.h
+++ b/include/tvm/runtime/disco/builtin.h
@@ -89,14 +89,14 @@ void AllGather(NDArray send, NDArray recv);
  * \param send The buffer to be broadcasted
  * \param recv The buffer receives the broadcasted array
  */
-void BroadcastFromWorker0(NDArray send, NDArray recv);
+TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv);
 /*!
  * \brief Perform a scatter operation from worker-0, chunking the given buffer into equal parts.
  * \param send For worker-0, it must be provided, and otherwise, the buffer must be None.
  *             The buffer will be divided into equal parts and sent to each worker accordingly.
  * \param recv The receiving buffer, which must not be None.
  */
-void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
+TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv);
 /*!
  * \brief Perform a gather operation to worker-0.
  * \param send The sending buffer, which must not be None.
diff --git a/include/tvm/runtime/disco/disco_worker.h b/include/tvm/runtime/disco/disco_worker.h
index 0c666150d4..14f8f23807 100644
--- a/include/tvm/runtime/disco/disco_worker.h
+++ b/include/tvm/runtime/disco/disco_worker.h
@@ -60,7 +60,7 @@ class DiscoWorker {
   /*! \brief Main loop of the worker */
   void MainLoop();
   /*! \brief Get the worker instance on the current thread */
-  static DiscoWorker* ThreadLocal();
+  TVM_DLL static DiscoWorker* ThreadLocal();
   /*! \brief Set the specific register to a specific value */
   void SetRegister(int reg_id, TVMArgValue value);
 
diff --git a/include/tvm/runtime/relax_vm/ndarray_cache_support.h b/include/tvm/runtime/relax_vm/ndarray_cache_support.h
index 3d8b639ee4..584da8f0ca 100644
--- a/include/tvm/runtime/relax_vm/ndarray_cache_support.h
+++ b/include/tvm/runtime/relax_vm/ndarray_cache_support.h
@@ -63,10 +63,10 @@ struct NDArrayCacheMetadata {
     };
 
     /*! \brief Load a FileRecord into memory */
-    Array<NDArray> Load(Device device,                   //
-                        const std::string& path_prefix,  //
-                        std::string* raw_data_buffer,    //
-                        Optional<NDArray>* staging_buffer = nullptr) const;
+    TVM_DLL Array<NDArray> Load(Device device,                   //
+                                const std::string& path_prefix,  //
+                                std::string* raw_data_buffer,    //
+                                Optional<NDArray>* staging_buffer = nullptr) const;
 
     /*! \brief Relative path to the bin file */
     std::string data_path;
@@ -83,7 +83,7 @@ struct NDArrayCacheMetadata {
   std::string path;
 
   /*! \brief Load the metadata from a specific directory */
-  static NDArrayCacheMetadata Load(const std::string& path);
+  TVM_DLL static NDArrayCacheMetadata Load(const std::string& path);
   /*! \brief Load the metadata from a given JSON string */
   static NDArrayCacheMetadata LoadFromStr(const std::string& json_str, const std::string& path);
 };
diff --git a/src/runtime/disco/builtin.cc b/src/runtime/disco/builtin.cc
index 51fe4c13fc..911fdaae3d 100644
--- a/src/runtime/disco/builtin.cc
+++ b/src/runtime/disco/builtin.cc
@@ -85,11 +85,11 @@ void AllReduce(NDArray send, ReduceKind reduce_kind, NDArray recv) {
 
 void AllGather(NDArray send, NDArray recv) { GetCCLFunc("allgather")(send, recv); }
 
-void BroadcastFromWorker0(NDArray send, NDArray recv) {
+TVM_DLL void BroadcastFromWorker0(NDArray send, NDArray recv) {
   GetCCLFunc("broadcast_from_worker0")(send, recv);
 }
 
-void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
+TVM_DLL void ScatterFromWorker0(Optional<NDArray> send, NDArray recv) {
   GetCCLFunc("scatter_from_worker0")(send, recv);
 }
 
diff --git a/src/runtime/disco/disco_worker.cc b/src/runtime/disco/disco_worker.cc
index d3c6d6a383..e8ba351e79 100644
--- a/src/runtime/disco/disco_worker.cc
+++ b/src/runtime/disco/disco_worker.cc
@@ -37,7 +37,7 @@ struct ThreadLocalDiscoWorker {
   }
 };
 
-DiscoWorker* DiscoWorker::ThreadLocal() {
+TVM_DLL DiscoWorker* DiscoWorker::ThreadLocal() {
   DiscoWorker* ret = ThreadLocalDiscoWorker::Get()->worker;
   CHECK(ret) << "ValueError: The current thread is not a DiscoWorker thread";
   return ret;
diff --git a/src/runtime/relax_vm/ndarray_cache_support.cc b/src/runtime/relax_vm/ndarray_cache_support.cc
index 613c70bb44..ce028f4d7d 100644
--- a/src/runtime/relax_vm/ndarray_cache_support.cc
+++ b/src/runtime/relax_vm/ndarray_cache_support.cc
@@ -123,7 +123,7 @@ NDArrayCacheMetadata NDArrayCacheMetadata::LoadFromStr(const std::string& json_s
   return result;
 }
 
-NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) {
+TVM_DLL NDArrayCacheMetadata NDArrayCacheMetadata::Load(const std::string& path) {
   picojson::value json_info;
   {
     std::string json_str;
@@ -183,10 +183,11 @@ NDArray NDArrayCacheMetadata::FileRecord::ParamRecord::Load(
   return arr;
 }
 
-Array<NDArray> NDArrayCacheMetadata::FileRecord::Load(Device device,
-                                                      const std::string& path_prefix,  //
-                                                      std::string* raw_data_buffer,    //
-                                                      Optional<NDArray>* staging_buffer) const {
+TVM_DLL Array<NDArray> NDArrayCacheMetadata::FileRecord::Load(
+    Device device,
+    const std::string& path_prefix,  //
+    std::string* raw_data_buffer,    //
+    Optional<NDArray>* staging_buffer) const {
   LoadBinaryFromFile(path_prefix + "/" + this->data_path, raw_data_buffer);
   CHECK_EQ(this->format, "raw-shard") << "ValueError: Only `raw-shard` format is supported";
   CHECK_EQ(this->nbytes, raw_data_buffer->length())