This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 820f1b617a [Runtime] Compatibility with dmlc::Stream API changes (#16998)
820f1b617a is described below

commit 820f1b617a4f8ccf196803c5e48a4f155c929c4a
Author: Eric Lunderberg <lunderb...@users.noreply.github.com>
AuthorDate: Thu May 30 11:41:03 2024 -0500

    [Runtime] Compatibility with dmlc::Stream API changes (#16998)
    
    * [Runtime] Compatibility with dmlc::Stream API changes
    
    This commit updates TVM implementations of `dmlc::Stream`.  With
    https://github.com/dmlc/dmlc-core/pull/686, this API now requires
    the `Write` method to return the number of bytes written.  This change
    allows partial writes to be correctly handled.
    
    * Update dmlc-core version
    
    * lint fix
---
 3rdparty/dmlc-core                    |  2 +-
 src/runtime/disco/process_session.cc  |  3 ++-
 src/runtime/disco/threaded_session.cc |  3 ++-
 src/runtime/file_utils.h              |  8 ++++++--
 src/runtime/rpc/rpc_endpoint.cc       |  8 ++++++--
 src/runtime/rpc/rpc_socket_impl.cc    |  7 ++-----
 src/support/base64.h                  |  5 +++--
 src/support/pipe.h                    | 24 +++++++++++-------------
 8 files changed, 33 insertions(+), 27 deletions(-)

diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core
index 09511cf9fe..3031e4a61a 160000
--- a/3rdparty/dmlc-core
+++ b/3rdparty/dmlc-core
@@ -1 +1 @@
-Subproject commit 09511cf9fe5ff103900a5eafb50870dc84cc17c8
+Subproject commit 3031e4a61a98f49f07a42cfdec6242340fb2fd8c
diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc
index b507758777..179010db8a 100644
--- a/src/runtime/disco/process_session.cc
+++ b/src/runtime/disco/process_session.cc
@@ -113,10 +113,11 @@ class DiscoPipeMessageQueue : private dmlc::Stream, private DiscoProtocol<DiscoP
     return size;
   }
 
-  void Write(const void* data, size_t size) final {
+  size_t Write(const void* data, size_t size) final {
     size_t cur_size = write_buffer_.size();
     write_buffer_.resize(cur_size + size);
     std::memcpy(write_buffer_.data() + cur_size, data, size);
+    return size;
   }
 
   using dmlc::Stream::Read;
diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc
index 7a76a45ed5..22f906b809 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -96,10 +96,11 @@ class DiscoThreadedMessageQueue : private dmlc::Stream,
     return size;
   }
 
-  void Write(const void* data, size_t size) final {
+  size_t Write(const void* data, size_t size) final {
     size_t cur_size = write_buffer_.size();
     write_buffer_.resize(cur_size + size);
     std::memcpy(write_buffer_.data() + cur_size, data, size);
+    return size;
   }
 
   using dmlc::Stream::Read;
diff --git a/src/runtime/file_utils.h b/src/runtime/file_utils.h
index 20806f5ff1..0f3dc13571 100644
--- a/src/runtime/file_utils.h
+++ b/src/runtime/file_utils.h
@@ -149,10 +149,14 @@ struct SimpleBinaryFileStream : public dmlc::Stream {
     CHECK(fp_ != nullptr) << "File is closed";
     return std::fread(ptr, 1, size, fp_);
   }
-  virtual void Write(const void* ptr, size_t size) {
+  virtual size_t Write(const void* ptr, size_t size) {
     CHECK(!read_) << "File opened in read-mode, cannot write.";
     CHECK(fp_ != nullptr) << "File is closed";
-    CHECK(std::fwrite(ptr, 1, size, fp_) == size) << "SimpleBinaryFileStream.Write incomplete";
+    size_t nwrite = std::fwrite(ptr, 1, size, fp_);
+    int err = std::ferror(fp_);
+
+    CHECK_EQ(err, 0) << "SimpleBinaryFileStream.Write incomplete: " << std::strerror(err);
+    return nwrite;
   }
   inline void Close(void) {
     if (fp_ != nullptr) {
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index b4f455cc18..5d04ee8387 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -666,8 +666,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     pending_request_bytes_ -= size;
     return size;
   }
-  // wriite the data to the channel.
-  void Write(const void* data, size_t size) final { writer_->Write(data, size); }
+  // write the data to the channel.
+  size_t Write(const void* data, size_t size) final {
+    writer_->Write(data, size);
+    return size;
+  }
+
   // Number of pending bytes requests
   size_t pending_request_bytes_{0};
   // The ring buffer to read data from.
diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc
index 1d0b5d5470..6882ba4ded 100644
--- a/src/runtime/rpc/rpc_socket_impl.cc
+++ b/src/runtime/rpc/rpc_socket_impl.cc
@@ -159,11 +159,8 @@ class SimpleSockHandler : public dmlc::Stream {
   // Internal supporting.
   // Override methods that inherited from dmlc::Stream.
  private:
-  size_t Read(void* data, size_t size) final {
-    ICHECK_EQ(sock_.RecvAll(data, size), size);
-    return size;
-  }
-  void Write(const void* data, size_t size) final { ICHECK_EQ(sock_.SendAll(data, size), size); }
+  size_t Read(void* data, size_t size) final { return sock_.Recv(data, size); }
+  size_t Write(const void* data, size_t size) final { return sock_.Send(data, size); }
 
   // Things of current class.
  private:
diff --git a/src/support/base64.h b/src/support/base64.h
index aba4197bce..2bfc42c27f 100644
--- a/src/support/base64.h
+++ b/src/support/base64.h
@@ -206,7 +206,7 @@ class Base64InStream : public dmlc::Stream {
     }
     return size - tlen;
   }
-  virtual void Write(const void* ptr, size_t size) {
+  size_t Write(const void* ptr, size_t size) final {
     LOG(FATAL) << "Base64InStream do not support write";
   }
 
@@ -229,7 +229,7 @@ class Base64OutStream : public dmlc::Stream {
 
   using dmlc::Stream::Write;
 
-  void Write(const void* ptr, size_t size) final {
+  size_t Write(const void* ptr, size_t size) final {
     using base64::EncodeTable;
     size_t tlen = size;
     const unsigned char* cptr = static_cast<const unsigned char*>(ptr);
@@ -247,6 +247,7 @@ class Base64OutStream : public dmlc::Stream {
         buf__top_ = 0;
       }
     }
+    return size;
   }
   virtual size_t Read(void* ptr, size_t size) {
     LOG(FATAL) << "Base64OutStream do not support read";
diff --git a/src/support/pipe.h b/src/support/pipe.h
index 7251a6f14a..9d5aa1e486 100644
--- a/src/support/pipe.h
+++ b/src/support/pipe.h
@@ -112,8 +112,8 @@ class Pipe : public dmlc::Stream {
    * \param size block size
    * \return the size of data read
    */
-  void Write(const void* ptr, size_t size) final {
-    if (size == 0) return;
+  size_t Write(const void* ptr, size_t size) final {
+    if (size == 0) return 0;
 #ifdef _WIN32
     auto fwrite = [&]() -> ssize_t {
       DWORD nwrite;
@@ -124,18 +124,16 @@ class Pipe : public dmlc::Stream {
     DWORD nwrite = static_cast<DWORD>(RetryCallOnEINTR(fwrite, GetLastErrorCode));
     ICHECK_EQ(static_cast<size_t>(nwrite), size) << "Write Error: " << GetLastError();
 #else
-    while (size) {
-      ssize_t nwrite =
-          RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
-      ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno);
-
-      ICHECK_GT(nwrite, 0) << "Was unable to write any data to pipe";
-      ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, "
-                              << "but only expected to write " << size << " bytes";
-      size -= nwrite;
-      ptr = static_cast<const char*>(ptr) + nwrite;
-    }
+    ssize_t nwrite =
+        RetryCallOnEINTR([&]() { return write(handle_, ptr, size); }, GetLastErrorCode);
+    ICHECK_NE(nwrite, -1) << "Write Error: " << strerror(errno);
+
+    ICHECK_LE(nwrite, size) << "Wrote " << nwrite << " bytes, "
+                            << "but only expected to write " << size << " bytes";
+
 #endif
+
+    return nwrite;
   }
   /*!
    * \brief Flush the pipe;

Reply via email to