This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 2175f6b  [RPC] Improve RPCServer AsyncIO support. (#5544)
2175f6b is described below

commit 2175f6be6d1fd414a14b73deb78808b80e2ba032
Author: Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Fri May 8 15:55:22 2020 -0700

    [RPC] Improve RPCServer AsyncIO support. (#5544)
    
    * [RPC] Improve RPCServer AsyncIO support.
    
    When the RPCServer is in the async IO mode, it is possible for the server
    to directly serve async function that may return its value via a callback in the future.
    This mode is particularly useful to the web environment, where blocking is not an option.
    
    This PR introduces the Async support to the RPCSession, allowing the AsyncIO driven servers
    to serve the async functions. These functions will still be presented as synchronized version
    on the client side.
    
    Followup PR will refactor the web runtime to make use of this feature.
    
    * Address comments
---
 src/runtime/rpc/rpc_endpoint.cc      | 267 ++++++++++++++++++++++-------------
 src/runtime/rpc/rpc_local_session.cc |  40 +++---
 src/runtime/rpc/rpc_local_session.h  |  23 +--
 src/runtime/rpc/rpc_session.cc       |  87 ++++++++++++
 src/runtime/rpc/rpc_session.h        | 109 +++++++++++++-
 5 files changed, 397 insertions(+), 129 deletions(-)

diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index 8a7f11c..26f24c9 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -57,11 +57,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
   EventHandler(support::RingBuffer* reader,
                support::RingBuffer* writer,
                std::string name,
-               std::string* remote_key)
+               std::string* remote_key,
+               std::function<void()> flush_writer)
       : reader_(reader),
         writer_(writer),
         name_(name),
-        remote_key_(remote_key) {
+        remote_key_(remote_key),
+        flush_writer_(flush_writer) {
     this->Clear();
 
     if (*remote_key == "%toinit") {
@@ -109,13 +111,21 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
   /*!
    * \brief Enter the io loop until the next event.
    * \param client_mode Whether we are in the client.
+   * \param async_server_mode Whether we are in the async server mode.
    * \param setreturn The function to set the return value encoding.
    * \return The function to set return values when there is a return event.
    */
-  RPCCode HandleNextEvent(bool client_mode, RPCSession::FEncodeReturn setreturn) {
+  RPCCode HandleNextEvent(bool client_mode,
+                          bool async_server_mode,
+                          RPCSession::FEncodeReturn setreturn) {
     std::swap(client_mode_, client_mode);
+    std::swap(async_server_mode_, async_server_mode);
 
-    while (this->Ready()) {
+    RPCCode status = RPCCode::kNone;
+
+    while (status == RPCCode::kNone &&
+           state_ != kWaitForAsyncCallback &&
+           this->Ready()) {
       switch (state_) {
         case kInitHeader: HandleInitHeader(); break;
         case kRecvPacketNumBytes: {
@@ -133,23 +143,27 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
           this->HandleProcessPacket(setreturn);
           break;
         }
+        case kWaitForAsyncCallback: {
+          break;
+        }
         case kReturnReceived: {
           this->SwitchToState(kRecvPacketNumBytes);
-          std::swap(client_mode_, client_mode);
-          return RPCCode::kReturn;
+          status = RPCCode::kReturn;
+          break;
         }
         case kCopyAckReceived: {
-          std::swap(client_mode_, client_mode);
-          return RPCCode::kCopyAck;
+          status = RPCCode::kCopyAck;
+          break;
         }
         case kShutdownReceived: {
-          std::swap(client_mode_, client_mode);
-          return RPCCode::kShutdown;
+          status = RPCCode::kShutdown;
         }
       }
     }
+
+    std::swap(async_server_mode_, async_server_mode);
     std::swap(client_mode_, client_mode);
-    return RPCCode::kNone;
+    return status;
   }
 
   /*! \brief Clear all the states in the Handler.*/
@@ -229,6 +243,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     kInitHeader,
     kRecvPacketNumBytes,
     kProcessPacket,
+    kWaitForAsyncCallback,
     kReturnReceived,
     kCopyAckReceived,
     kShutdownReceived
@@ -239,6 +254,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
   bool init_header_step_{0};
   // Whether current handler is client or server mode.
   bool client_mode_{false};
+  // Whether current handler is in the async server mode.
+  bool async_server_mode_{false};
   // Internal arena
   support::Arena arena_;
 
@@ -249,6 +266,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
       CHECK_EQ(pending_request_bytes_, 0U)
           << "state=" << state;
     }
+    // need to actively flush the writer
+    // so the data get pushed out.
+    if (state_ == kWaitForAsyncCallback) {
+      flush_writer_();
+    }
     state_ = state;
     CHECK(state != kInitHeader)
         << "cannot switch to init header";
@@ -389,41 +411,50 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     this->Read(&type_hint);
     size_t elem_bytes = (type_hint.bits * type_hint.lanes + 7) / 8;
 
-    char* data_ptr;
     auto* sess = GetServingSession();
 
+    // Return Copy Ack with the given data
+    auto fcopyack = [this](char* data_ptr, size_t num_bytes) {
+      RPCCode code = RPCCode::kCopyAck;
+      uint64_t packet_nbytes = sizeof(code) + num_bytes;
+
+      this->Write(packet_nbytes);
+      this->Write(code);
+      this->WriteArray(data_ptr, num_bytes);
+      this->SwitchToState(kRecvPacketNumBytes);
+    };
+
     // When session is local, we can directly treat handle
     // as the cpu pointer without allocating a temp space.
     if (ctx.device_type == kDLCPU &&
         sess->IsLocalSession() &&
         DMLC_IO_NO_ENDIAN_SWAP) {
-      data_ptr = reinterpret_cast<char*>(handle) + offset;
+      char* data_ptr = reinterpret_cast<char*>(handle) + offset;
+      fcopyack(data_ptr, num_bytes);
     } else {
-      try {
-        data_ptr = this->ArenaAlloc<char>(num_bytes);
-        sess->CopyFromRemote(
-            reinterpret_cast<void*>(handle), offset,
-            data_ptr, 0,
-            num_bytes, ctx, type_hint);
-        // endian aware handling
-        if (!DMLC_IO_NO_ENDIAN_SWAP) {
-          dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes);
+      char* data_ptr = this->ArenaAlloc<char>(num_bytes);
+
+      auto on_copy_complete = [this, elem_bytes, num_bytes, data_ptr, fcopyack](
+          RPCCode status, TVMArgs args) {
+        if (status == RPCCode::kException) {
+          this->ReturnException(args.values[0].v_str);
+          this->SwitchToState(kRecvPacketNumBytes);
+        } else {
+          // endian aware handling
+          if (!DMLC_IO_NO_ENDIAN_SWAP) {
+            dmlc::ByteSwap(data_ptr, elem_bytes, num_bytes / elem_bytes);
+          }
+          fcopyack(data_ptr, num_bytes);
         }
-      } catch (const std::runtime_error &e) {
-        this->ReturnException(e.what());
-        this->SwitchToState(kRecvPacketNumBytes);
-        return;
-      }
+      };
+
+      this->SwitchToState(kWaitForAsyncCallback);
+      sess->AsyncCopyFromRemote(
+          reinterpret_cast<void*>(handle), offset,
+          data_ptr, 0,
+          num_bytes, ctx, type_hint,
+          on_copy_complete);
     }
-    RPCCode code = RPCCode::kCopyAck;
-    uint64_t packet_nbytes = sizeof(code) + num_bytes;
-
-    // Return Copy Ack
-    this->Write(packet_nbytes);
-    this->Write(code);
-    this->WriteArray(data_ptr, num_bytes);
-
-    this->SwitchToState(kRecvPacketNumBytes);
   }
 
   void HandleCopyToRemote() {
@@ -446,9 +477,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
        char* dptr = reinterpret_cast<char*>(handle) + offset;
        this->ReadArray(dptr, num_bytes);
 
-        if (!DMLC_IO_NO_ENDIAN_SWAP) {
-          dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes);
-        }
+       if (!DMLC_IO_NO_ENDIAN_SWAP) {
+         dmlc::ByteSwap(dptr, elem_bytes, num_bytes / elem_bytes);
+       }
+       this->ReturnVoid();
+       this->SwitchToState(kRecvPacketNumBytes);
     } else {
       char* temp_data = this->ArenaAlloc<char>(num_bytes);
       this->ReadArray(temp_data, num_bytes);
@@ -457,20 +490,23 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
         dmlc::ByteSwap(temp_data, elem_bytes, num_bytes / elem_bytes);
       }
 
-      try {
-        sess->CopyToRemote(
+      auto on_copy_complete = [this](RPCCode status, TVMArgs args) {
+        if (status == RPCCode::kException) {
+          this->ReturnException(args.values[0].v_str);
+          this->SwitchToState(kRecvPacketNumBytes);
+        } else {
+          this->ReturnVoid();
+          this->SwitchToState(kRecvPacketNumBytes);
+        }
+      };
+
+      this->SwitchToState(kWaitForAsyncCallback);
+      sess->AsyncCopyToRemote(
             temp_data, 0,
             reinterpret_cast<void*>(handle), offset,
-            num_bytes, ctx, type_hint);
-      } catch (const std::runtime_error &e) {
-        this->ReturnException(e.what());
-        this->SwitchToState(kRecvPacketNumBytes);
-        return;
-      }
+            num_bytes, ctx, type_hint,
+            on_copy_complete);
     }
-
-    this->ReturnVoid();
-    this->SwitchToState(kRecvPacketNumBytes);
   }
 
   // Handle for packed call.
@@ -480,16 +516,18 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     this->Read(&call_handle);
     TVMArgs args = RecvPackedSeq();
 
-    try {
-      GetServingSession()->CallFunc(
-          reinterpret_cast<void*>(call_handle),
-          args.values, args.type_codes, args.size(),
-          [this](TVMArgs ret) { this->ReturnPackedSeq(ret); });
-    } catch (const std::runtime_error& e) {
-      this->ReturnException(e.what());
-    }
-
-    this->SwitchToState(kRecvPacketNumBytes);
+    this->SwitchToState(kWaitForAsyncCallback);
+    GetServingSession()->AsyncCallFunc(
+        reinterpret_cast<void*>(call_handle),
+        args.values, args.type_codes, args.size(),
+        [this](RPCCode status, TVMArgs args) {
+          if (status == RPCCode::kException) {
+            this->ReturnException(args.values[0].v_str);
+          } else {
+            this->ReturnPackedSeq(args);
+          }
+          this->SwitchToState(kRecvPacketNumBytes);
+        });
   }
 
   void HandleInitServer() {
@@ -512,35 +550,39 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
           << " server protocol=" << server_protocol_ver
           << ", client protocol=" << client_protocol_ver;
 
+      std::string constructor_name;
+      TVMArgs constructor_args = TVMArgs(nullptr, nullptr, 0);
+
       if (args.size() == 0) {
+        constructor_name = "rpc.LocalSession";
         serving_session_ = std::make_shared<LocalSession>();
       } else {
-        std::string constructor_name = args[0];
-        auto* fconstructor = Registry::Get(constructor_name);
-        CHECK(fconstructor != nullptr)
-            << " Cannot find session constructor " << constructor_name;
-        TVMRetValue con_ret;
-
-        try {
-          fconstructor->CallPacked(
-              TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1), &con_ret);
-        } catch (const dmlc::Error& e) {
-          LOG(FATAL) << "Server[" << name_ << "]:"
-                     << " Error caught from session constructor " << constructor_name
-                     << ":\n" << e.what();
-        }
+        constructor_name = args[0].operator std::string();
+        constructor_args = TVMArgs(args.values + 1, args.type_codes + 1, args.size() - 1);
+      }
+
+      auto* fconstructor = Registry::Get(constructor_name);
+      CHECK(fconstructor != nullptr)
+          << " Cannot find session constructor " << constructor_name;
+      TVMRetValue con_ret;
 
-        CHECK_EQ(con_ret.type_code(), kTVMModuleHandle)
-            << "Server[" << name_ << "]:"
-            << " Constructor " << constructor_name
-            << " need to return an RPCModule";
-        Module mod = con_ret;
-        std::string tkey = mod->type_key();
-        CHECK_EQ(tkey, "rpc")
-            << "Constructor " << constructor_name << " to return an RPCModule";
-        serving_session_ = RPCModuleGetSession(mod);
+      try {
+        fconstructor->CallPacked(constructor_args, &con_ret);
+      } catch (const dmlc::Error& e) {
+        LOG(FATAL) << "Server[" << name_ << "]:"
+                   << " Error caught from session constructor " << constructor_name
+                   << ":\n" << e.what();
       }
 
+      CHECK_EQ(con_ret.type_code(), kTVMModuleHandle)
+          << "Server[" << name_ << "]:"
+          << " Constructor " << constructor_name
+          << " need to return an RPCModule";
+      Module mod = con_ret;
+      std::string tkey = mod->type_key();
+      CHECK_EQ(tkey, "rpc")
+          << "Constructor " << constructor_name << " to return an RPCModule";
+      serving_session_ = RPCModuleGetSession(mod);
       this->ReturnVoid();
     } catch (const std::runtime_error &e) {
       this->ReturnException(e.what());
@@ -549,6 +591,28 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     this->SwitchToState(kRecvPacketNumBytes);
   }
 
+  void HandleSyscallStreamSync() {
+    TVMArgs args = RecvPackedSeq();
+    try {
+      TVMContext ctx = args[0];
+      TVMStreamHandle handle = args[1];
+
+      this->SwitchToState(kWaitForAsyncCallback);
+      GetServingSession()->AsyncStreamWait(
+          ctx, handle, [this](RPCCode status, TVMArgs args) {
+            if (status == RPCCode::kException) {
+              this->ReturnException(args.values[0].v_str);
+            } else {
+              this->ReturnVoid();
+            }
+            this->SwitchToState(kRecvPacketNumBytes);
+          });
+    } catch (const std::runtime_error& e) {
+      this->ReturnException(e.what());
+      this->SwitchToState(kRecvPacketNumBytes);
+    }
+  }
+
   // Handler for special syscalls that have a specific RPCCode.
   template<typename F>
   void SysCallHandler(F f) {
@@ -572,6 +636,9 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
   RPCSession* GetServingSession() const {
     CHECK(serving_session_ != nullptr)
         << "Need to call InitRemoteSession first before any further actions";
+    CHECK(!serving_session_->IsAsync() || async_server_mode_)
+        << "Cannot host an async session in a non-Event driven server";
+
     return serving_session_.get();
   }
   // Utility functions
@@ -598,10 +665,13 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
   std::string name_;
   // remote key
   std::string* remote_key_;
+  // function to flush the writer.
+  std::function<void()> flush_writer_;
 };
 
 RPCCode RPCEndpoint::HandleUntilReturnEvent(
-    bool client_mode, RPCSession::FEncodeReturn setreturn) {
+    bool client_mode,
+    RPCSession::FEncodeReturn setreturn) {
   RPCCode code = RPCCode::kCallFunc;
   while (code != RPCCode::kReturn &&
          code != RPCCode::kShutdown &&
@@ -624,15 +694,26 @@ RPCCode RPCEndpoint::HandleUntilReturnEvent(
         }
       }
     }
-    code = handler_->HandleNextEvent(client_mode, setreturn);
+    code = handler_->HandleNextEvent(client_mode, false, setreturn);
   }
   return code;
 }
 
 void RPCEndpoint::Init() {
+  // callback to flush the writer.
+  auto flush_writer = [this]() {
+    while (writer_.bytes_available() != 0) {
+      size_t n = writer_.ReadWithCallback([this](const void *data, size_t size) {
+        return channel_->Send(data, size);
+      }, writer_.bytes_available());
+      if (n == 0) break;
+    }
+  };
+
   // Event handler
   handler_ = std::make_shared<EventHandler>(
-      &reader_, &writer_, name_, &remote_key_);
+      &reader_, &writer_, name_, &remote_key_, flush_writer);
+
   // Quick function to for syscall remote.
   syscall_remote_ = PackedFunc([this](TVMArgs all_args, TVMRetValue* rv) {
     std::lock_guard<std::mutex> lock(mutex_);
@@ -711,7 +792,7 @@ int RPCEndpoint::ServerAsyncIOEventHandler(const std::string& in_bytes, int even
   RPCCode code = RPCCode::kNone;
   if (in_bytes.length() != 0) {
     reader_.Write(in_bytes.c_str(), in_bytes.length());
-    code = handler_->HandleNextEvent(false, [](TVMArgs) {});
+    code = handler_->HandleNextEvent(false, true, [](TVMArgs) {});
   }
   if ((event_flag & 2) != 0 && writer_.bytes_available() != 0) {
     writer_.ReadWithCallback([this](const void *data, size_t size) {
@@ -894,12 +975,6 @@ void RPCDevFreeData(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
   handler->GetDeviceAPI(ctx)->FreeDataSpace(ctx, ptr);
 }
 
-void RPCDevStreamSync(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
-  TVMContext ctx = args[0];
-  TVMStreamHandle handle = args[1];
-  handler->GetDeviceAPI(ctx)->StreamSync(ctx, handle);
-}
-
 void RPCCopyAmongRemote(RPCSession* handler, TVMArgs args, TVMRetValue *rv) {
   void* from = args[0];
   uint64_t from_offset = args[1];
@@ -935,12 +1010,14 @@ void RPCEndpoint::EventHandler::HandleSyscall(RPCCode code) {
     case RPCCode::kDevGetAttr: SysCallHandler(RPCDevGetAttr); break;
     case RPCCode::kDevAllocData: SysCallHandler(RPCDevAllocData); break;
     case RPCCode::kDevFreeData: SysCallHandler(RPCDevFreeData); break;
-    case RPCCode::kDevStreamSync: SysCallHandler(RPCDevStreamSync); break;
+    case RPCCode::kDevStreamSync: this->HandleSyscallStreamSync(); break;
     case RPCCode::kCopyAmongRemote: SysCallHandler(RPCCopyAmongRemote); break;
     default: LOG(FATAL) << "Unknown event " << static_cast<int>(code);
   }
 
-  CHECK_EQ(state_, kRecvPacketNumBytes);
+  if (state_ != kWaitForAsyncCallback) {
+    CHECK_EQ(state_, kRecvPacketNumBytes);
+  }
 }
 
 /*!
diff --git a/src/runtime/rpc/rpc_local_session.cc b/src/runtime/rpc/rpc_local_session.cc
index 351a989..9d1fb72 100644
--- a/src/runtime/rpc/rpc_local_session.cc
+++ b/src/runtime/rpc/rpc_local_session.cc
@@ -31,21 +31,15 @@ namespace runtime {
 
 RPCSession::PackedFuncHandle
 LocalSession::GetFunction(const std::string& name) {
-  PackedFunc pf = this->GetFunctionInternal(name);
-  // return raw handl because the remote need to explicitly manage it.
-  if (pf != nullptr) return new PackedFunc(pf);
-  return nullptr;
+  if (auto* fp = tvm::runtime::Registry::Get(name)) {
+    // return raw handle because the remote need to explicitly manage it.
+    return new PackedFunc(*fp);
+  } else {
+    return nullptr;
+  }
 }
 
-void LocalSession::CallFunc(RPCSession::PackedFuncHandle func,
-                            const TVMValue* arg_values,
-                            const int* arg_type_codes,
-                            int num_args,
-                            const FEncodeReturn& encode_return) {
-  auto* pf = static_cast<PackedFunc*>(func);
-  TVMRetValue rv;
-
-  pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv);
+void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return) {
   int rv_tcode = rv.type_code();
 
   // return value encoding.
@@ -84,6 +78,17 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func,
   }
 }
 
+void LocalSession::CallFunc(RPCSession::PackedFuncHandle func,
+                            const TVMValue* arg_values,
+                            const int* arg_type_codes,
+                            int num_args,
+                            const FEncodeReturn& encode_return) {
+  auto* pf = static_cast<PackedFunc*>(func);
+  TVMRetValue rv;
+  pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv);
+  this->EncodeReturn(std::move(rv), encode_return);
+}
+
 void LocalSession::CopyToRemote(void* from,
                                 size_t from_offset,
                                 void* to,
@@ -134,15 +139,6 @@ DeviceAPI* LocalSession::GetDeviceAPI(TVMContext ctx, bool allow_missing) {
   return DeviceAPI::Get(ctx, allow_missing);
 }
 
-PackedFunc LocalSession::GetFunctionInternal(const std::string& name) {
-  auto* fp = tvm::runtime::Registry::Get(name);
-  if (fp != nullptr) {
-    return *fp;
-  } else {
-    return nullptr;
-  }
-}
-
 TVM_REGISTER_GLOBAL("rpc.LocalSession")
 .set_body_typed([]() {
   return CreateRPCSessionModule(std::make_shared<LocalSession>());
diff --git a/src/runtime/rpc/rpc_local_session.h b/src/runtime/rpc/rpc_local_session.h
index 3b6e7d8..ff0caa4 100644
--- a/src/runtime/rpc/rpc_local_session.h
+++ b/src/runtime/rpc/rpc_local_session.h
@@ -28,6 +28,7 @@
 #include <tvm/runtime/device_api.h>
 #include <functional>
 #include <string>
+#include <utility>
 #include "rpc_session.h"
 
 namespace tvm {
@@ -40,13 +41,13 @@ namespace runtime {
 class LocalSession : public RPCSession {
  public:
   // function overrides
-  PackedFuncHandle GetFunction(const std::string& name) final;
+  PackedFuncHandle GetFunction(const std::string& name) override;
 
   void CallFunc(PackedFuncHandle func,
                 const TVMValue* arg_values,
                 const int* arg_type_codes,
                 int num_args,
-                const FEncodeReturn& fencode_return) final;
+                const FEncodeReturn& fencode_return) override;
 
   void CopyToRemote(void* from,
                     size_t from_offset,
@@ -54,7 +55,7 @@ class LocalSession : public RPCSession {
                     size_t to_offset,
                     size_t nbytes,
                     TVMContext ctx_to,
-                    DLDataType type_hint) final;
+                    DLDataType type_hint) override;
 
   void CopyFromRemote(void* from,
                       size_t from_offset,
@@ -62,23 +63,23 @@ class LocalSession : public RPCSession {
                       size_t to_offset,
                       size_t nbytes,
                       TVMContext ctx_from,
-                      DLDataType type_hint) final;
+                      DLDataType type_hint) override;
 
-  void FreeHandle(void* handle, int type_code) final;
+  void FreeHandle(void* handle, int type_code) override;
 
-  DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) final;
+  DeviceAPI* GetDeviceAPI(TVMContext ctx, bool allow_missing = false) override;
 
-  bool IsLocalSession() const final {
+  bool IsLocalSession() const override {
     return true;
   }
 
  protected:
   /*!
-   * \brief Internal implementation of GetFunction.
-   * \param name The name of the function.
-   * \return The corresponding PackedFunc.
+   * \brief internal encode return fucntion.
+   * \param rv The return value.
+   * \param encode_return The encoding function.
    */
-  virtual PackedFunc GetFunctionInternal(const std::string& name);
+  void EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_return);
 };
 
 }  // namespace runtime
diff --git a/src/runtime/rpc/rpc_session.cc b/src/runtime/rpc/rpc_session.cc
index dd0afa0..d07aa74 100644
--- a/src/runtime/rpc/rpc_session.cc
+++ b/src/runtime/rpc/rpc_session.cc
@@ -30,6 +30,93 @@
 namespace tvm {
 namespace runtime {
 
+bool RPCSession::IsAsync() const {
+  return false;
+}
+
+void RPCSession::SendException(FAsyncCallback callback, const char* msg) {
+  TVMValue value;
+  value.v_str = msg;
+  int32_t tcode = kTVMStr;
+  callback(RPCCode::kException, TVMArgs(&value, &tcode, 1));
+}
+
+void RPCSession::AsyncCallFunc(PackedFuncHandle func,
+                               const TVMValue* arg_values,
+                               const int* arg_type_codes,
+                               int num_args,
+                               FAsyncCallback callback) {
+  try {
+    this->CallFunc(func, arg_values, arg_type_codes, num_args,
+                   [&callback](TVMArgs args) {
+                     callback(RPCCode::kReturn, args);
+                   });
+  } catch (const std::runtime_error& e) {
+    this->SendException(callback, e.what());
+  }
+}
+
+
+void RPCSession::AsyncCopyToRemote(void* local_from,
+                                   size_t local_from_offset,
+                                   void* remote_to,
+                                   size_t remote_to_offset,
+                                   size_t nbytes,
+                                   TVMContext remote_ctx_to,
+                                   DLDataType type_hint,
+                                   RPCSession::FAsyncCallback callback) {
+  TVMValue value;
+  int32_t tcode = kTVMNullptr;
+  value.v_handle = nullptr;
+
+  try {
+    this->CopyToRemote(local_from, local_from_offset,
+                       remote_to, remote_to_offset,
+                       nbytes, remote_ctx_to, type_hint);
+    callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
+  } catch (const std::runtime_error& e) {
+    this->SendException(callback, e.what());
+  }
+}
+
+void RPCSession::AsyncCopyFromRemote(void* remote_from,
+                                     size_t remote_from_offset,
+                                     void* local_to,
+                                     size_t local_to_offset,
+                                     size_t nbytes,
+                                     TVMContext remote_ctx_from,
+                                     DLDataType type_hint,
+                                     RPCSession::FAsyncCallback callback) {
+  TVMValue value;
+  int32_t tcode = kTVMNullptr;
+  value.v_handle = nullptr;
+
+  try {
+    this->CopyFromRemote(remote_from, remote_from_offset,
+                         local_to, local_to_offset,
+                         nbytes, remote_ctx_from, type_hint);
+    callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
+  } catch (const std::runtime_error& e) {
+    this->SendException(callback, e.what());
+  }
+}
+
+void RPCSession::AsyncStreamWait(TVMContext ctx,
+                                 TVMStreamHandle stream,
+                                 RPCSession::FAsyncCallback callback) {
+  TVMValue value;
+  int32_t tcode = kTVMNullptr;
+  value.v_handle = nullptr;
+
+  try {
+    this->GetDeviceAPI(ctx)->StreamSync(ctx, stream);
+    callback(RPCCode::kReturn, TVMArgs(&value, &tcode, 1));
+  } catch (const std::runtime_error& e) {
+    this->SendException(callback, e.what());
+  }
+}
+
+
 class RPCSessTable {
  public:
   static constexpr int kMaxRPCSession = 32;
diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h
index e7e4433..7ea1eb9 100644
--- a/src/runtime/rpc/rpc_session.h
+++ b/src/runtime/rpc/rpc_session.h
@@ -30,6 +30,7 @@
 #include <functional>
 #include <memory>
 #include <string>
+#include "rpc_protocol.h"
 
 namespace tvm {
 namespace runtime {
@@ -58,7 +59,6 @@ class RPCSession {
    * \brief Callback to send an encoded return values via encode_args.
    *
    * \param encode_args The arguments that we can encode the return values into.
-   * \param ret_tcode The actual remote type code of the return value.
    *
    * Encoding convention (as list of arguments):
    * - str/float/int/byte: [tcode: int, value: TVMValue] value follows PackedFunc convention.
@@ -69,6 +69,14 @@ class RPCSession {
    */
   using FEncodeReturn = std::function<void(TVMArgs encoded_args)>;
 
+  /*!
+   * \brief Callback to send an encoded return values via encode_args.
+   *
+   * \param status The return status, can be RPCCode::kReturn or RPCCode::kException.
+   * \param encode_args The arguments that we can encode the return values into.
+   */
+  using FAsyncCallback = std::function<void(RPCCode status, TVMArgs encoded_args)>;
+
   /*! \brief Destructor.*/
   virtual ~RPCSession() {}
 
@@ -189,6 +197,98 @@ class RPCSession {
    */
   virtual bool IsLocalSession() const = 0;
 
+  // Asynchrous variant of API
+  // These APIs are used by the RPC server to allow sessions that
+  // have special implementations for the async functions.
+  //
+  // In the async APIs, an exception is returned by the passing
+  // async_error=true, encode_args=[error_msg].
+
+  /*!
+   * \brief Whether the session is async.
+   *
+   * If the session is not async, its Aync implementations
+   * simply calls into the their synchronize counterparts,
+   * and the callback is guaranteed to be called before the async function finishes.
+   *
+   * \return the async state.
+   *
+   * \note We can only use async session in an Event driven RPC server.
+   */
+  virtual bool IsAsync() const;
+
+  /*!
+   * \brief Asynchrously call func.
+   * \param func The function handle.
+   * \param arg_values The argument values.
+   * \param arg_type_codes the type codes of the argument.
+   * \param num_args Number of arguments.
+   *
+   * \param callback The callback to pass the return value or exception.
+   */
+  virtual void AsyncCallFunc(PackedFuncHandle func,
+                             const TVMValue* arg_values,
+                             const int* arg_type_codes,
+                             int num_args,
+                             FAsyncCallback callback);
+
+  /*!
+   * \brief Asynchrous version of CopyToRemote.
+   *
+   * \param local_from The source host data.
+   * \param local_from_offset The byte offeset in the from.
+   * \param remote_to The target array.
+   * \param remote_to_offset The byte offset in the to.
+   * \param nbytes The size of the memory in bytes.
+   * \param remote_ctx_to The target context.
+   * \param type_hint Hint of content data type.
+   *
+   * \param on_complete The callback to signal copy complete.
+   * \note All the allocated memory in local_from, and remote_to
+   *       must stay alive until on_compelete is called.
+   */
+  virtual void AsyncCopyToRemote(void* local_from,
+                                 size_t local_from_offset,
+                                 void* remote_to,
+                                 size_t remote_to_offset,
+                                 size_t nbytes,
+                                 TVMContext remote_ctx_to,
+                                 DLDataType type_hint,
+                                 FAsyncCallback on_complete);
+
+  /*!
+   * \brief Asynchrous version of CopyFromRemote.
+   *
+   * \param remote_from The source host data.
+   * \param remote_from_offset The byte offeset in the from.
+   * \param to The target array.
+   * \param to_offset The byte offset in the to.
+   * \param nbytes The size of the memory in bytes.
+   * \param remote_ctx_from The source context in the remote.
+   * \param type_hint Hint of content data type.
+   *
+   * \param on_complete The callback to signal copy complete.
+   * \note All the allocated memory in remote_from, and local_to
+   *       must stay alive until on_compelete is called.
+   */
+  virtual void AsyncCopyFromRemote(void* remote_from,
+                                   size_t remote_from_offset,
+                                   void* local_to,
+                                   size_t local_to_offset,
+                                   size_t nbytes,
+                                   TVMContext remote_ctx_from,
+                                   DLDataType type_hint,
+                                   FAsyncCallback on_complete);
+  /*!
+   * \brief Asynchrously wait for all events in ctx, stream compeletes.
+   * \param ctx The device context.
+   * \param stream The stream to wait on.
+   * \param on_complete The callback to signal copy complete.
+   */
+  virtual void AsyncStreamWait(TVMContext ctx,
+                               TVMStreamHandle stream,
+                               FAsyncCallback on_compelte);
+
   /*!
    * \return The session table index of the session.
    */
@@ -203,6 +303,13 @@ class RPCSession {
    */
   static std::shared_ptr<RPCSession> Get(int table_index);
 
+ protected:
+  /*!
+   * \brief Send an exception to the callback.
+   * \param msg The exception message.
+   */
+  void SendException(FAsyncCallback callback, const char* msg);
+
  private:
   /*! \brief index of this session in RPC session table */
   int table_index_{0};

Reply via email to