This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4258c864b9 [RUNTIME][RPC] Enable RPCObjectRef return in RPC (#16387)
4258c864b9 is described below

commit 4258c864b91f1b0b5cffc5ba792a331998f793bd
Author: Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Fri Jan 12 11:53:19 2024 -0500

    [RUNTIME][RPC] Enable RPCObjectRef return in RPC (#16387)
    
    [Runtime] Enable RPCObjectRef return in RPC
    
    This PR enables RPCObjectRef return object similar to the disco 
transportation.
    This allows us to do advanced remote debugging when remote vm requires
    advanced object input like kv cache and shape.
    
    To stay compatible with minRPC (used in some of the limited 
protocols) for now,
    we only support RPCObjectRef and do not enable unpacking Shape and 
String.
---
 include/tvm/runtime/object.h             |  4 ++-
 src/runtime/minrpc/minrpc_server.h       | 15 ++++++++--
 src/runtime/minrpc/rpc_reference.h       |  8 +++++
 src/runtime/rpc/rpc_endpoint.cc          | 51 +++++++++++++++++++++++++++-----
 src/runtime/rpc/rpc_local_session.cc     | 20 +++++++++++--
 src/runtime/rpc/rpc_module.cc            |  7 +++++
 src/runtime/rpc/rpc_session.h            | 51 +++++++++++++++++++++++++++++++-
 tests/python/runtime/test_runtime_rpc.py | 31 +++++++++++++++++++
 8 files changed, 174 insertions(+), 13 deletions(-)

diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h
index 94644d797c..92f477b058 100644
--- a/include/tvm/runtime/object.h
+++ b/include/tvm/runtime/object.h
@@ -72,8 +72,10 @@ struct TypeIndex {
     kRuntimeShapeTuple = 6,
     /*! \brief runtime::PackedFunc. */
     kRuntimePackedFunc = 7,
-    /*! \brief runtime::DRef */
+    /*! \brief runtime::DRef for disco distributed runtime */
     kRuntimeDiscoDRef = 8,
+    /*! \brief runtime::RPCObjectRef */
+    kRuntimeRPCObjectRef = 9,
     // static assignments that may subject to change.
     kRuntimeClosure,
     kRuntimeADT,
diff --git a/src/runtime/minrpc/minrpc_server.h 
b/src/runtime/minrpc/minrpc_server.h
index cca47f80b9..96a4dbce79 100644
--- a/src/runtime/minrpc/minrpc_server.h
+++ b/src/runtime/minrpc/minrpc_server.h
@@ -206,7 +206,8 @@ class MinRPCExecute : public MinRPCExecInterface {
         ret_tcode[1] = kTVMBytes;
         ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
         
TVMByteArrayFree(reinterpret_cast<TVMByteArray*>(ret_value[1].v_handle));  // 
NOLINT(*)
-      } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == 
kTVMModuleHandle) {
+      } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == 
kTVMModuleHandle ||
+                 rv_tcode == kTVMObjectHandle) {
         ret_tcode[1] = kTVMOpaqueHandle;
         ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
       } else {
@@ -755,7 +756,17 @@ class MinRPCServer {
   }
 
   void ReadObject(int* tcode, TVMValue* value) {
-    this->ThrowError(RPCServerStatus::kUnknownTypeCode);
+    // handles RPCObject in minRPC
+    // NOTE: object needs to be supported by C runtime
+    // because minrpc's restriction of C only
+    // we only handle RPCObjectRef
+    uint32_t type_index;
+    Read(&type_index);
+    MINRPC_CHECK(type_index == kRuntimeRPCObjectRefTypeIndex);
+    uint64_t object_handle;
+    Read(&object_handle);
+    tcode[0] = kTVMObjectHandle;
+    value[0].v_handle = reinterpret_cast<void*>(object_handle);
   }
 
  private:
diff --git a/src/runtime/minrpc/rpc_reference.h 
b/src/runtime/minrpc/rpc_reference.h
index e16f09cb9d..732b017e44 100644
--- a/src/runtime/minrpc/rpc_reference.h
+++ b/src/runtime/minrpc/rpc_reference.h
@@ -33,6 +33,14 @@ class Object;
 /*! \brief The current RPC procotol version. */
 constexpr const char* kRPCProtocolVer = "0.8.0";
 
+/*!
+ * \brief Type index of runtime::RPCObjectRef.
+ * \note This needs to be kept consistent with runtime/object.h,
+ * but we explicitly declare it here because minrpc needs to keep minimum
+ * dependencies and rely only on the C API.
+ */
+constexpr const int kRuntimeRPCObjectRefTypeIndex = 9;
+
 // When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered.
 const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;
 
diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc
index f2c09132fc..2c431cdb64 100644
--- a/src/runtime/rpc/rpc_endpoint.cc
+++ b/src/runtime/rpc/rpc_endpoint.cc
@@ -175,8 +175,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     for (int i = 0; i < num_args; ++i) {
       int tcode = type_codes[i];
       if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) {
-        LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
-                   << args[i].AsObjectRef<ObjectRef>()->GetTypeKey() << " is 
not supported by RPC";
+        if (!args[i].IsObjectRef<RPCObjectRef>()) {
+          LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
+                     << args[i].AsObjectRef<ObjectRef>()->GetTypeKey()
+                     << " is not supported by RPC";
+        }
       } else if (tcode == kDLDevice) {
         DLDevice dev = args[i];
         ICHECK(!IsRPCSessionDevice(dev)) << "InternalError: cannot pass RPC 
device in the channel";
@@ -219,14 +222,48 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
     this->Write(cdata);
   }
 
-  void WriteObject(void* obj) { 
this->ThrowError(RPCServerStatus::kUnknownTypeCode); }
-  uint64_t GetObjectBytes(void* obj) {
-    this->ThrowError(RPCServerStatus::kUnknownTypeCode);
-    return 0;
+  void WriteObject(Object* obj) {
+    // NOTE: for now all remote object are encoded as RPCObjectRef
+    // follow the same disco protocol in case we would like to upgrade later
+    //
+    // Rationale note: Only handle remote object allows the same mechanism to 
work for minRPC
+    // which is needed for wasm and other env that goes through C API
+    if (obj->IsInstance<RPCObjectRefObj>()) {
+      auto* ref = static_cast<RPCObjectRefObj*>(obj);
+      this->template Write<uint32_t>(kRuntimeRPCObjectRefTypeIndex);
+      uint64_t handle = reinterpret_cast<uint64_t>(ref->object_handle());
+      this->template Write<int64_t>(handle);
+    } else {
+      LOG(FATAL) << "ValueError: Object type is not supported in RPC calling 
convention: "
+                 << obj->GetTypeKey() << " (type_index = " << 
obj->type_index() << ")";
+    }
+  }
+  uint64_t GetObjectBytes(Object* obj) {
+    if (obj->IsInstance<RPCObjectRefObj>()) {
+      return sizeof(uint32_t) + sizeof(int64_t);
+    } else {
+      LOG(FATAL) << "ValueError: Object type is not supported in RPC calling 
convention: "
+                 << obj->GetTypeKey() << " (type_index = " << 
obj->type_index() << ")";
+    }
   }
 
   void ReadObject(int* tcode, TVMValue* value) {
-    this->ThrowError(RPCServerStatus::kUnknownTypeCode);
+    // NOTE: for now all remote object are encoded as RPCObjectRef
+    // follow the same disco protocol in case we would like to upgrade later
+    //
+    // Rationale note: Only handle remote object allows the same mechanism to 
work for minRPC
+    // which is needed for wasm and other env that goes through C API
+    uint32_t type_index;
+    this->template Read<uint32_t>(&type_index);
+    if (type_index == kRuntimeRPCObjectRefTypeIndex) {
+      uint64_t handle;
+      this->template Read<uint64_t>(&handle);
+      tcode[0] = kTVMObjectHandle;
+      value[0].v_handle = reinterpret_cast<void*>(handle);
+    } else {
+      LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
+                 << Object::TypeIndex2Key(type_index) << " (type_index = " << 
type_index << ")";
+    }
   }
 
   void MessageDone() {
diff --git a/src/runtime/rpc/rpc_local_session.cc 
b/src/runtime/rpc/rpc_local_session.cc
index d4aec5596f..92691ee6fd 100644
--- a/src/runtime/rpc/rpc_local_session.cc
+++ b/src/runtime/rpc/rpc_local_session.cc
@@ -27,6 +27,7 @@
 #include <tvm/runtime/registry.h>
 
 #include <memory>
+#include <vector>
 
 namespace tvm {
 namespace runtime {
@@ -64,7 +65,8 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const 
FEncodeReturn& encode_retu
     ret_value_pack[2].v_handle = ret_value_pack[1].v_handle;
     ret_tcode_pack[2] = kTVMOpaqueHandle;
     encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3));
-  } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) 
{
+  } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle 
||
+             rv_tcode == kTVMObjectHandle) {
     // MoveToCHost means rv no longer manages the object.
     // return handle instead.
     rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
@@ -88,7 +90,21 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle 
func, const TVMValue* a
                             const FEncodeReturn& encode_return) {
   PackedFuncObj* pf = static_cast<PackedFuncObj*>(func);
   TVMRetValue rv;
-  pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv);
+
+  // unwrap RPCObjectRef in case we are directly using it to call LocalSession
+  std::vector<TVMValue> values(arg_values, arg_values + num_args);
+  std::vector<int> type_codes(arg_type_codes, arg_type_codes + num_args);
+  TVMArgs args(arg_values, arg_type_codes, num_args);
+
+  for (int i = 0; i < num_args; ++i) {
+    if (args[i].IsObjectRef<RPCObjectRef>()) {
+      RPCObjectRef obj_ref = args[i];
+      values[i].v_handle = obj_ref->object_handle();
+      continue;
+    }
+  }
+
+  pf->CallPacked(TVMArgs(values.data(), type_codes.data(), args.size()), &rv);
   this->EncodeReturn(std::move(rv), encode_return);
 }
 
diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc
index 94f6720ca8..a696005ab8 100644
--- a/src/runtime/rpc/rpc_module.cc
+++ b/src/runtime/rpc/rpc_module.cc
@@ -157,6 +157,8 @@ class RPCWrappedFunc : public Object {
   }
 };
 
+TVM_REGISTER_OBJECT_TYPE(RPCObjectRefObj);
+
 // RPC that represents a remote module session.
 class RPCModuleNode final : public ModuleNode {
  public:
@@ -294,6 +296,11 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, 
TVMRetValue* rv) cons
     void* handle = args[1];
     auto n = make_object<RPCModuleNode>(handle, sess_);
     *rv = Module(n);
+  } else if (tcode == kTVMObjectHandle) {
+    ICHECK_EQ(args.size(), 2);
+    void* handle = args[1];
+    auto n = make_object<RPCObjectRefObj>(handle, sess_);
+    *rv = ObjectRef(n);
   } else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) {
     ICHECK_EQ(args.size(), 3);
     DLTensor* tensor = args[1];
diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h
index 60d067e49d..b09900d0ab 100644
--- a/src/runtime/rpc/rpc_session.h
+++ b/src/runtime/rpc/rpc_session.h
@@ -142,7 +142,7 @@ class RPCSession {
 
   /*!
    * \brief Free a remote function.
-   * \param handle The remote handle, can be NDArray/PackedFunc/Module
+   * \param handle The remote handle, can be NDArray/PackedFunc/Module/Object
    * \param type_code The type code of the underlying type.
    */
   virtual void FreeHandle(void* handle, int type_code) = 0;
@@ -287,6 +287,55 @@ struct RemoteSpace {
   std::shared_ptr<RPCSession> sess;
 };
 
+/*!
+ * \brief Object wrapper that represents a reference to a remote object
+ */
+class RPCObjectRefObj : public Object {
+ public:
+  /*!
+   * \brief constructor
+   * \param object_handle handle that points to the remote object
+   * \param sess The remote session
+   */
+  RPCObjectRefObj(void* object_handle, std::shared_ptr<RPCSession> sess)
+      : object_handle_(object_handle), sess_(sess) {}
+
+  ~RPCObjectRefObj() {
+    if (object_handle_ != nullptr) {
+      try {
+        sess_->FreeHandle(object_handle_, kTVMObjectHandle);
+      } catch (const Error& e) {
+        // fault tolerance to remote close
+      }
+      object_handle_ = nullptr;
+    }
+  }
+
+  const std::shared_ptr<RPCSession>& sess() const { return sess_; }
+
+  void* object_handle() const { return object_handle_; }
+
+  static constexpr const uint32_t _type_index = 
TypeIndex::kRuntimeRPCObjectRef;
+  static constexpr const char* _type_key = "runtime.RPCObjectRef";
+  TVM_DECLARE_FINAL_OBJECT_INFO(RPCObjectRefObj, Object);
+
+ private:
+  // The object handle
+  void* object_handle_{nullptr};
+  // The local channel
+  std::shared_ptr<RPCSession> sess_;
+};
+
+/*!
+ * \brief Managed reference to RPCObjectRefObj.
+ * \sa RPCObjectRefObj
+ * \note No public constructor is provided as it is not supposed to be 
directly created by users.
+ */
+class RPCObjectRef : public ObjectRef {
+ public:
+  TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef, 
RPCObjectRefObj);
+};
+
 /*!
  * \brief Create a Global RPC module that refers to the session.
  * \param sess The RPC session of the global module.
diff --git a/tests/python/runtime/test_runtime_rpc.py 
b/tests/python/runtime/test_runtime_rpc.py
index 9591e3ea4d..fff203df00 100644
--- a/tests/python/runtime/test_runtime_rpc.py
+++ b/tests/python/runtime/test_runtime_rpc.py
@@ -426,6 +426,7 @@ def test_rpc_return_ndarray():
     ref_count = m("ref_count")
     get_elem = m("get_elem")
     get_arr_elem = m("get_arr_elem")
+
     # array test
     def run_arr_test():
         arr = get_arr()
@@ -435,6 +436,36 @@ def test_rpc_return_ndarray():
     run_arr_test()
 
 
+@tvm.testing.requires_rpc
+def test_rpc_return_remote_object():
+    def check(client, is_local):
+        make_shape = client.get_function("runtime.ShapeTuple")
+        get_elem = client.get_function("runtime.GetShapeTupleElem")
+        get_size = client.get_function("runtime.GetShapeTupleSize")
+        shape = make_shape(2, 3)
+        assert shape.type_key == "runtime.RPCObjectRef"
+        assert get_elem(shape, 0) == 2
+        assert get_elem(shape, 1) == 3
+        assert get_size(shape) == 2
+
+    # start server
+    server = rpc.Server(key="x1")
+    client = rpc.connect("127.0.0.1", server.port, key="x1")
+    check(rpc.LocalSession(), True)
+    check(client, False)
+
+    def check_minrpc():
+        if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is 
None:
+            return
+        # Test minrpc server.
+        temp = utils.tempdir()
+        minrpc_exec = temp.relpath("minrpc")
+        tvm.rpc.with_minrpc(cc.create_executable)(minrpc_exec, [])
+        check(rpc.PopenSession(minrpc_exec), False)
+
+    check_minrpc()
+
+
 @tvm.testing.requires_rpc
 def test_local_func():
     client = rpc.LocalSession()

Reply via email to