This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0230c77cb7 [Disco] Support ShapeTuple in Disco Protocol (#15634)
0230c77cb7 is described below

commit 0230c77cb700ccc462bf990fc1e22f06a7943f92
Author: Junru Shao <[email protected]>
AuthorDate: Mon Aug 28 11:27:24 2023 -0700

    [Disco] Support ShapeTuple in Disco Protocol (#15634)
    
    ShapeTuple is an essential item used in the Relax stack as the runtime
    representation of shapes. It is also currently used as a workaround to
    represent integers (as 1-d shapes), given that standalone non-constant
    integers are currently unavailable in Relax.
    
    This PR introduces formal support for ShapeTuple in Disco's
    communication protocol, building on the recent enhancement of the TVM RPC
    system to support TVM Objects: https://github.com/apache/tvm/pull/15631.
---
 src/runtime/disco/threaded_session.cc | 22 ++++++++++++++------
 tests/python/disco/test_session.py    | 39 +++++++++++++++++++++++++++++++++++
 2 files changed, 55 insertions(+), 6 deletions(-)

diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc
index 04860ef712..cb84918d2d 100644
--- a/src/runtime/disco/threaded_session.cc
+++ b/src/runtime/disco/threaded_session.cc
@@ -110,6 +110,9 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
   uint64_t GetObjectBytes(Object* obj) {
     if (obj->IsInstance<DRefObj>()) {
       return sizeof(uint32_t) + sizeof(int64_t);
+    } else if (obj->IsInstance<StringObj>()) {
+      uint64_t size = static_cast<StringObj*>(obj)->size;
+      return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char);
     } else if (obj->IsInstance<ShapeTupleObj>()) {
       uint64_t ndim = static_cast<ShapeTupleObj*>(obj)->size;
       return sizeof(uint32_t) + sizeof(uint64_t) + ndim * 
sizeof(ShapeTupleObj::index_type);
@@ -124,13 +127,16 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
       int64_t reg_id = static_cast<DRefObj*>(obj)->reg_id;
       this->Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef);
       this->Write<int64_t>(reg_id);
+    } else if (obj->IsInstance<StringObj>()) {
+      StringObj* str = static_cast<StringObj*>(obj);
+      this->Write<uint32_t>(TypeIndex::kRuntimeString);
+      this->Write<uint64_t>(str->size);
+      this->WriteArray<char>(str->data, str->size);
     } else if (obj->IsInstance<ShapeTupleObj>()) {
       ShapeTupleObj* shape = static_cast<ShapeTupleObj*>(obj);
       this->Write<uint32_t>(TypeIndex::kRuntimeShapeTuple);
       this->Write<uint64_t>(shape->size);
-      for (uint64_t i = 0; i < shape->size; ++i) {
-        this->Write<ShapeTupleObj::index_type>(shape->data[i]);
-      }
+      this->WriteArray<ShapeTupleObj::index_type>(shape->data, shape->size);
     } else {
       LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
                  << obj->GetTypeKey() << " (type_index = " << 
obj->type_index() << ")";
@@ -146,13 +152,17 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
       this->Read<int64_t>(&dref->reg_id);
       dref->session = Session{nullptr};
       result = ObjectRef(std::move(dref));
+    } else if (type_index == TypeIndex::kRuntimeString) {
+      uint64_t size = 0;
+      this->Read<uint64_t>(&size);
+      std::string data(size, '\0');
+      this->ReadArray<char>(data.data(), size);
+      result = String(std::move(data));
     } else if (type_index == TypeIndex::kRuntimeShapeTuple) {
       uint64_t ndim = 0;
       this->Read<uint64_t>(&ndim);
       std::vector<ShapeTupleObj::index_type> data(ndim);
-      for (ShapeTupleObj::index_type& i : data) {
-        this->Read<ShapeTupleObj::index_type>(&i);
-      }
+      this->ReadArray<ShapeTupleObj::index_type>(data.data(), ndim);
       result = ShapeTuple(std::move(data));
     } else {
       LOG(FATAL) << "ValueError: Object type is not supported in Disco calling 
convention: "
diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py
index 2e5afe35f7..a2c0906f22 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -23,6 +23,7 @@ import numpy as np
 import tvm
 from tvm import relax as rx
 from tvm._ffi import register_func
+from tvm.runtime import ShapeTuple, String
 from tvm.runtime import disco as di
 from tvm.script import ir as I
 from tvm.script import relax as R
@@ -106,6 +107,42 @@ def test_string():
         assert result.debug_get_from_remote(i) == "hello_suffix"
 
 
+def test_string_obj():
+    num_workers = 4
+
+    @register_func("tests.disco.str_obj", override=True)
+    def my_str_func(x: String):  # pylint: disable=invalid-name
+        assert isinstance(x, String)
+        return String(x + "_suffix")
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    func: di.DPackedFunc = sess.get_global_func("tests.disco.str_obj")
+    result: di.DRef = func(String("hello"))
+
+    for i in range(num_workers):
+        value = result.debug_get_from_remote(i)
+        assert isinstance(value, String)
+        assert value == "hello_suffix"
+
+
+def test_shape_tuple():
+    num_workers = 4
+
+    @register_func("tests.disco.shape_tuple", override=True)
+    def my_str_func(x: ShapeTuple):  # pylint: disable=invalid-name
+        assert isinstance(x, ShapeTuple)
+        return ShapeTuple(list(x) + [4, 5])
+
+    sess = di.ThreadedSession(num_workers=num_workers)
+    func: di.DPackedFunc = sess.get_global_func("tests.disco.shape_tuple")
+    result: di.DRef = func(ShapeTuple([1, 2, 3]))
+
+    for i in range(num_workers):
+        value = result.debug_get_from_remote(i)
+        assert isinstance(value, ShapeTuple)
+        assert list(value) == [1, 2, 3, 4, 5]
+
+
 def test_vm_module():
     num_workers = 4
 
@@ -210,6 +247,8 @@ if __name__ == "__main__":
     test_int()
     test_float()
     test_string()
+    test_string_obj()
+    test_shape_tuple()
     test_ndarray()
     test_vm_module()
     test_vm_multi_func()

Reply via email to