This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 820f1b617a [Runtime] Compatibility with dmlc::Stream API changes (#16998) 820f1b617a is described below commit 820f1b617a4f8ccf196803c5e48a4f155c929c4a Author: Eric Lunderberg <lunderb...@users.noreply.github.com> AuthorDate: Thu May 30 11:41:03 2024 -0500 [Runtime] Compatibility with dmlc::Stream API changes (#16998) * [Runtime] Compatibility with dmlc::Stream API changes This commit updates TVM implementations of `dmlc::Stream`. With https://github.com/dmlc/dmlc-core/pull/686, this API now requires the `Write` method to return the number of bytes written. This change allows partial writes to be correctly handled. * Update dmlc-core version * lint fix --- 3rdparty/dmlc-core | 2 +- src/runtime/disco/process_session.cc | 3 ++- src/runtime/disco/threaded_session.cc | 3 ++- src/runtime/file_utils.h | 8 ++++++-- src/runtime/rpc/rpc_endpoint.cc | 8 ++++++-- src/runtime/rpc/rpc_socket_impl.cc | 7 ++----- src/support/base64.h | 5 +++-- src/support/pipe.h | 24 +++++++++++------------- 8 files changed, 33 insertions(+), 27 deletions(-) diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 09511cf9fe..3031e4a61a 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 09511cf9fe5ff103900a5eafb50870dc84cc17c8 +Subproject commit 3031e4a61a98f49f07a42cfdec6242340fb2fd8c diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index b507758777..179010db8a 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -113,10 +113,11 @@ class DiscoPipeMessageQueue : private dmlc::Stream, private DiscoProtocol<DiscoP return size; } - void Write(const void* data, size_t size) final { + size_t Write(const void* data, size_t size) final { size_t cur_size = write_buffer_.size(); write_buffer_.resize(cur_size + size); std::memcpy(write_buffer_.data() + cur_size, data, size); + return size; } using dmlc::Stream::Read; diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index 7a76a45ed5..22f906b809 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -96,10 +96,11 @@ class DiscoThreadedMessageQueue : private dmlc::Stream, return size; } - void Write(const void* data, size_t size) final { + size_t Write(const void* data, size_t size) final { size_t cur_size = write_buffer_.size(); write_buffer_.resize(cur_size + size); std::memcpy(write_buffer_.data() + cur_size, data, size); + return size; } using dmlc::Stream::Read; diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h index 20806f5ff1..0f3dc13571 100644 --- a/src/runtime/file_utils.h +++ b/src/runtime/file_utils.h @@ -149,10 +149,14 @@ struct SimpleBinaryFileStream : public dmlc::Stream { CHECK(fp_ != nullptr) << "File is closed"; return std::fread(ptr, 1, size, fp_); } - virtual void Write(const void* ptr, size_t size) { + virtual size_t Write(const void* ptr, size_t size) { CHECK(!read_) << "File opened in read-mode, cannot write."; CHECK(fp_ != nullptr) << "File is closed"; - CHECK(std::fwrite(ptr, 1, size, fp_) == size) << "SimpleBinaryFileStream.Write incomplete"; + size_t nwrite = std::fwrite(ptr, 1, size, fp_); + int err = std::ferror(fp_); + + CHECK_EQ(err, 0) << "SimpleBinaryFileStream.Write incomplete: " << std::strerror(err); + return nwrite; } inline void Close(void) { if (fp_ != nullptr) { diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index b4f455cc18..5d04ee8387 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -666,8 +666,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { pending_request_bytes_ -= size; return size; } - // wriite the data to the channel. - void Write(const void* data, size_t size) final { writer_->Write(data, size); } + // write the data to the channel. + size_t Write(const void* data, size_t size) final { + writer_->Write(data, size); + return size; + } + // Number of pending bytes requests size_t pending_request_bytes_{0}; // The ring buffer to read data from. diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 1d0b5d5470..6882ba4ded 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -159,11 +159,8 @@ class SimpleSockHandler : public dmlc::Stream { // Internal supporting. // Override methods that inherited from dmlc::Stream. private: - size_t Read(void* data, size_t size) final { - ICHECK_EQ(sock_.RecvAll(data, size), size); - return size; - } - void Write(const void* data, size_t size) final { ICHECK_EQ(sock_.SendAll(data, size), size); } + size_t Read(void* data, size_t size) final { return sock_.Recv(data, size); } + size_t Write(const void* data, size_t size) final { return sock_.Send(data, size); } // Things of current class. private: diff --git a/src/support/base64.h b/src/support/base64.h index aba4197bce..2bfc42c27f 100644 --- a/src/support/base64.h +++ b/src/support/base64.h @@ -206,7 +206,7 @@ class Base64InStream : public dmlc::Stream { } return size - tlen; } - virtual void Write(const void* ptr, size_t size) { + size_t Write(const void* ptr, size_t size) final { LOG(FATAL) << "Base64InStream do not support write"; } @@ -229,7 +229,7 @@ class Base64OutStream : public dmlc::Stream { using dmlc::Stream::Write; - void Write(const void* ptr, size_t size) final { + size_t Write(const void* ptr, size_t size) final { using base64::EncodeTable; size_t tlen = size; const unsigned char* cptr = static_cast<const unsigned char*>(ptr); @@ -247,6 +247,7 @@ class Base64OutStream : public dmlc::Stream { buf__top_ = 0; } } + return size; } virtual size_t Read(void* ptr, size_t size) { LOG(FATAL) << "Base64OutStream do not support read"; diff --git a/src/support/pipe.h b/src/support/pipe.h index 7251a6f14a..9d5aa1e486 100644 --- a/src/support/pipe.h +++ b/src/support/pipe.h @@ -112,8 +112,8 @@ class Pipe : public dmlc::Stream { * \param size block size * \return the size of data read */ - void Write(const void* ptr, size_t size) final { - if (size == 0) return; + size_t Write(const void* ptr, size_t size) final { + if (size == 0) return 0; #ifdef _WIN32 auto fwrite = [&]() -> ssize_t { DWORD nwrite; @@ -124,18 +124,16 @@ class Pipe : public dmlc::Stream { DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode)); ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError(); #else - while (size) { - ssize_t nwrite = - RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode); - ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno); - - ICHECK_GT(nwrite, 0) << "Was unable to write any data to pipe"; - ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, " - << "but only expected to write " << size << " bytes"; - size -= nwrite; - ptr = static_cast<const char*>(ptr) + nwrite; - } + ssize_t nwrite = + RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode); + ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno); + + ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, " + << "but only expected to write " << size << " bytes"; + #endif + + return nwrite; } /*! * \brief Flush the pipe;