This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b57d0b3d07 [Runtime][Disco] Fix session attribute storage, NVSHMEM 
build, and test gating (#19736)
b57d0b3d07 is described below

commit b57d0b3d07d9e97d8e4ce16106f33c1e37940191
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 12 08:32:23 2026 -0400

    [Runtime][Disco] Fix session attribute storage, NVSHMEM build, and test 
gating (#19736)
    
    The tvm_ffi Object metaclass now gives every subclass `__slots__ = ()`,
    so the Disco Python wrappers can no longer store instance attributes and
    every session construction fails with an AttributeError. This change
    declares the attributes each wrapper actually stores as named slots,
    fixes the NVSHMEM `dist_gemm.cu` so TVM builds with `USE_NVSHMEM = ON`,
    and gates the disco tests on the disco runtime being present, so they
    skip cleanly on builds (e.g. the pip wheel) that report `USE_NCCL` /
    `USE_NVSHMEM = ON` without shipping it.
    
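    ### Session attribute storage
    - `DPackedFunc` / `DModule`: `__slots__ = ("session",)`.
    - `Session`: `__slots__ = ("_cache", "_import_python_module")`; the
      method-cache lookup now uses `hasattr` instead of `self.__dict__`.
    
    ### NVSHMEM build
    - `dist_gemm.cu`: qualify `ffi::GetDataSize` and replace the
      `if (worker == nullptr) LOG(FATAL)` checks with `TVM_FFI_ICHECK`.
    
    ### Test gating
    - Add `tvm.testing.requires_nvshmem`, which probes
      `runtime.disco.nvshmem.init_nvshmem_uid` rather than the cmake flag.
    - Skip the disco test modules when the disco runtime (or, for
      `test_loader.py`, disco NCCL support) is unavailable, and run each
      NVSHMEM test body in a freshly spawned process.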
---
 python/tvm/runtime/disco/session.py            |  15 +-
 python/tvm/testing/utils.py                    |  15 ++
 src/runtime/extra/contrib/nvshmem/dist_gemm.cu |  18 +--
 tests/python/disco/test_loader.py              |  10 ++
 tests/python/disco/test_nvshmem.py             | 206 +++++++++++++++----------
 tests/python/disco/test_session.py             | 109 ++++++++-----
 6 files changed, 235 insertions(+), 138 deletions(-)

diff --git a/python/tvm/runtime/disco/session.py 
b/python/tvm/runtime/disco/session.py
index 08afbbfc80..0a81250900 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -78,6 +78,10 @@ class DRef(Object):
 class DPackedFunc(DRef):
     """A PackedFunc in a Disco session."""
 
+    # tvm_ffi Object subclasses cannot store Python attributes by default
+    # (the metaclass sets `__slots__ = ()`); list the field(s) we store here.
+    __slots__ = ("session",)
+
     def __init__(self, dref: DRef, session: "Session") -> None:
         self.__move_handle_from__(dref)
         self.session = session
@@ -89,6 +93,10 @@ class DPackedFunc(DRef):
 class DModule(DRef):
     """A Module in a Disco session."""
 
+    # tvm_ffi Object subclasses cannot store Python attributes by default
+    # (the metaclass sets `__slots__ = ()`); list the field(s) we store here.
+    __slots__ = ("session",)
+
     def __init__(self, dref: DRef, session: "Session") -> None:
         self.__move_handle_from__(dref)
         self.session = session
@@ -103,8 +111,13 @@ class Session(Object):
     """A Disco interactive session. It allows users to interact with the Disco 
command queue with
     various PackedFunc calling convention."""
 
+    # tvm_ffi Object subclasses cannot store Python attributes by default
+    # (the metaclass sets `__slots__ = ()`); list the fields we store here:
+    # the method-lookup cache and the lazily bound import helper.
+    __slots__ = ("_cache", "_import_python_module")
+
     def _get_cached_method(self, name: str) -> Callable:
-        if "_cache" not in self.__dict__:
+        if not hasattr(self, "_cache"):
             cache = self._cache = {}  # pylint: 
disable=attribute-defined-outside-init
         else:
             cache = self._cache
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 3b0dea1bb5..cfc5a357a3 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -910,6 +910,21 @@ requires_cublas = Feature("cublas", "cuBLAS", 
cmake_flag="USE_CUBLAS", parent_fe
 # Mark a test as requiring NCCL support
 requires_nccl = Feature("nccl", "NCCL", cmake_flag="USE_NCCL", 
parent_features="cuda")
 
+
+def _nvshmem_exists():
+    # Probe the runtime function rather than the USE_NVSHMEM cmake flag: the
+    # flag can be ON in builds that do not ship the disco NVSHMEM runtime.
+    return (
+        tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", 
allow_missing=True)
+        is not None
+    )
+
+
+# Mark a test as requiring NVSHMEM support
+requires_nvshmem = Feature(
+    "nvshmem", "NVSHMEM", run_time_check=_nvshmem_exists, 
parent_features="cuda"
+)
+
 # Mark a test as requiring the NVPTX compilation on the CUDA runtime
 requires_nvptx = Feature(
     "nvptx",
diff --git a/src/runtime/extra/contrib/nvshmem/dist_gemm.cu 
b/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
index 512613f2ed..860cba6b52 100644
--- a/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
+++ b/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
@@ -37,7 +37,7 @@ void* get_pointer(Tensor data, ffi::Shape index) {
     offset *= data->shape[i];
     offset += index[i];
   }
-  return static_cast<void*>(ptr + offset * GetDataSize(1, data->dtype));
+  return static_cast<void*>(ptr + offset * ffi::GetDataSize(1, data->dtype));
 }
 
 void cuStreamWaitValue64Wrapper(TVMStreamHandle strm, void* addr, uint64_t 
expected) {
@@ -60,9 +60,7 @@ void copy_to_peer(void* dst, int dst_device, void* src, 
size_t size, TVMStreamHa
 
 TVMStreamHandle stream_create() {
   DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
-  if (worker == nullptr) {
-    LOG(FATAL) << "NVSHMEM stream creation failed: worker is not initialized";
-  }
+  TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM stream creation failed: worker 
is not initialized";
   cudaStream_t retval;
   CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
   return static_cast<TVMStreamHandle>(retval);
@@ -70,9 +68,7 @@ TVMStreamHandle stream_create() {
 
 void stream_sync(TVMStreamHandle from_stream, TVMStreamHandle to_stream) {
   DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
-  if (worker == nullptr) {
-    LOG(FATAL) << "NVSHMEM stream sync failed: worker is not initialized";
-  }
+  TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM stream sync failed: worker is 
not initialized";
   auto f_sync_stream = 
tvm::ffi::Function::GetGlobalRequired("runtime.Device_StreamSyncFromTo");
   f_sync_stream(worker->default_device, reinterpret_cast<int64_t>(from_stream),
                 reinterpret_cast<int64_t>(to_stream));
@@ -91,9 +87,7 @@ void transfer_to_peers_reduce_scatter(Tensor semaphore, 
Tensor gemm_out, Tensor
                                       TVMStreamHandle stream, int32_t M, 
int32_t N, int32_t BLK_M,
                                       int32_t BLK_N, int32_t WORLD_SIZE) {
   DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
-  if (worker == nullptr) {
-    LOG(FATAL) << "NVSHMEM transfer to peer failed: worker is not initialized";
-  }
+  TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM transfer to peer failed: 
worker is not initialized";
   int my_rank = worker->worker_id;
   int LOCAL_M = M / WORLD_SIZE;
   for (int i = 0; i < WORLD_SIZE; i++) {
@@ -118,9 +112,7 @@ void transfer_to_peers_reduce_scatter(Tensor semaphore, 
Tensor gemm_out, Tensor
 void transfer_to_peers_all_gather(Tensor semaphore, Tensor A, Tensor ag_out, 
TVMStreamHandle stream,
                                   int32_t M, int32_t K, int32_t WORLD_SIZE) {
   DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
-  if (worker == nullptr) {
-    LOG(FATAL) << "NVSHMEM transfer to peer failed: worker is not initialized";
-  }
+  TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM transfer to peer failed: 
worker is not initialized";
   int my_rank = worker->worker_id;
   int LOCAL_M = M / WORLD_SIZE;
   for (int i = 0; i < WORLD_SIZE; i++) {
diff --git a/tests/python/disco/test_loader.py 
b/tests/python/disco/test_loader.py
index b709571219..290b9f401f 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -22,6 +22,7 @@ import json
 import tempfile
 
 import numpy as np
+import pytest
 from tvm_ffi import Shape, register_global_func
 
 import tvm
@@ -34,6 +35,15 @@ from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.target import Target
 
+# `runtime.disco.compiled_ccl` is registered together with the CCL runtime
+# functions, so its absence means the disco CCL runtime is not in this build.
+_compiled_ccl = tvm.get_global_func("runtime.disco.compiled_ccl", 
allow_missing=True)
+if _compiled_ccl is None or _compiled_ccl() != "nccl":
+    pytest.skip("Disco NCCL support is not available", allow_module_level=True)
+
+# All tests in this file shard across two GPUs.
+pytestmark = tvm.testing.requires_multi_gpu.marks()
+
 
 @register_global_func("tests.disco.shard_dim_0", override=True)
 def _shard_dim_0(src, num_shards, tgt):
diff --git a/tests/python/disco/test_nvshmem.py 
b/tests/python/disco/test_nvshmem.py
index 77b57a2c0b..5d70ccf6bd 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -14,19 +14,17 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# ruff: noqa: F401, F841
 """Basic tests for a Disco nvshmem support"""
 
 # pylint: disable=missing-docstring
 import multiprocessing
+import os
 import shutil
+import socket
 import subprocess
 import sys
 import tempfile
 import threading
-from collections.abc import Callable
-from multiprocessing import Process
-from typing import Any
 
 import numpy as np
 import pytest
@@ -34,50 +32,63 @@ from tvm_ffi import Shape
 
 import tvm
 import tvm.testing
-from tvm.exec import disco_worker as _  # pylint: disable=unused-import
 from tvm.runtime import disco as di
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tirx as T
 
-_SOCKET_SESSION_TESTER = None
+if di is None:
+    pytest.skip("disco runtime is not available", allow_module_level=True)
+
+pytestmark = tvm.testing.requires_nvshmem.marks()
 
 
-def get_free_port():
-    import socket
+_SOCKET_SESSION_TESTER = None
 
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    s.bind(("", 0))
-    port = s.getsockname()[1]
-    s.close()
+
+def _get_free_port():
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.bind(("", 0))
+    port = sock.getsockname()[1]
+    sock.close()
     return port
 
 
 class SocketSessionTester:
-    def __init__(self, num_workers):
-        num_nodes = 2
-        num_groups = 1
+    """Run a disco SocketSession with one local node and remote nodes.
+
+    Each remote node is a `tvm.exec.disco_remote_socket_session` subprocess
+    launched with the current Python interpreter.
+    """
+
+    def __init__(self, num_workers, num_nodes=2, num_groups=1):
+        # Initialize the attributes used by __del__ first, so that teardown is
+        # safe even when __init__ raises below.
+        self.sess = None
+        self.remote_nodes = []
         assert num_workers % num_nodes == 0
         num_workers_per_node = num_workers // num_nodes
         server_host = "localhost"
-        server_port = get_free_port()
-        self.sess = None
+        server_port = _get_free_port()
+        server_exc = []
 
         def start_server():
-            self.sess = di.SocketSession(
-                num_nodes, num_workers_per_node, num_groups, server_host, 
server_port
-            )
+            try:
+                self.sess = di.SocketSession(
+                    num_nodes, num_workers_per_node, num_groups, server_host, 
server_port
+                )
+            except Exception as exc:  # pylint: disable=broad-except
+                server_exc.append(exc)
 
         thread = threading.Thread(target=start_server)
         thread.start()
 
         cmd = "tvm.exec.disco_remote_socket_session"
-        self.remote_nodes = []
         for _i in range(num_nodes - 1):
             self.remote_nodes.append(
                 subprocess.Popen(
                     [
-                        "python3",
+                        sys.executable,
                         "-m",
                         cmd,
                         server_host,
@@ -90,26 +101,73 @@ class SocketSessionTester:
             )
 
         thread.join()
+        if server_exc:
+            raise server_exc[0]
+
+    # Bound at class creation: module globals may already be cleared when
+    # __del__ runs during interpreter shutdown.
+    _TIMEOUT_EXPIRED = subprocess.TimeoutExpired
 
     def __del__(self):
-        if self.sess is not None:
-            self.sess.shutdown()
-            del self.sess
+        try:
+            # Shut down the session first so remote nodes can exit gracefully.
+            if self.sess is not None:
+                self.sess.shutdown()
+        finally:
+            for node in self.remote_nodes:
+                try:
+                    node.wait(timeout=10)
+                except self._TIMEOUT_EXPIRED:
+                    node.kill()
+                    node.wait()
 
 
 def create_socket_session(num_workers):
+    """Create a socket session backed by one local and one remote node.
+
+    The tester is kept alive in a module-level global so that the session
+    survives until the next call (or interpreter exit) replaces it.
+    """
     global _SOCKET_SESSION_TESTER
-    if _SOCKET_SESSION_TESTER is not None:
-        del _SOCKET_SESSION_TESTER
+    # Rebind (not `del`) so the global stays defined if the constructor raises.
+    _SOCKET_SESSION_TESTER = None
     _SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
     assert _SOCKET_SESSION_TESTER.sess is not None
     return _SOCKET_SESSION_TESTER.sess
 
 
-def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
-    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is 
None:
-        return
+_all_session_kinds = [di.ProcessSession, create_socket_session]
+_all_num_workers = [2, 4]
+
+_SUBPROCESS_TIMEOUT_SEC = 600
+
+
+def _run_in_fresh_process(target, *args):
+    """Run a test body in a freshly spawned process.
 
+    After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
+    or `nvshmem_init_thread` in the same program results in undefined behavior,
+    and worker-0 of a Disco session lives in the calling process. So each test
+    body must run in its own process. The 'spawn' start method avoids
+    inheriting CUDA state from this process.
+    """
+    proc = multiprocessing.get_context("spawn").Process(target=target, 
args=args)
+    proc.start()
+    proc.join(timeout=_SUBPROCESS_TIMEOUT_SEC)
+    if proc.is_alive():
+        proc.kill()
+        proc.join()
+        pytest.fail(f"{target.__name__}{args} timed out after 
{_SUBPROCESS_TIMEOUT_SEC} seconds")
+    assert proc.exitcode == 0, f"{target.__name__}{args} failed with exit code 
{proc.exitcode}"
+
+
+def _require_cuda_devices(num_workers):
+    # Each nvshmem worker binds its own CUDA device (cudaSetDevice(worker_id)).
+    if not all(tvm.cuda(i).exist for i in range(num_workers)):
+        pytest.skip(f"Requires {num_workers} CUDA devices")
+
+
+def _init_finalize(session_kind, num_workers):
     sess = session_kind(num_workers=num_workers)
     f_init_nvshmem_uid = 
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
     uid = f_init_nvshmem_uid()
@@ -121,10 +179,7 @@ def test_nvshmem_init_finalize(session_kind: di.Session, 
num_workers: int):
     sess.sync_worker_0()
 
 
-def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
-    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is 
None:
-        return
-
+def _empty(session_kind, num_workers):
     device = tvm.cuda()
     sess = session_kind(num_workers=num_workers)
     f_init_nvshmem_uid = 
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
@@ -133,18 +188,29 @@ def test_nvshmem_empty(session_kind: di.Session, 
num_workers: int):
     init_dfunc(uid, num_workers, 0)
     sess.sync_worker_0()
     empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
-    a = empty_dfunc(Shape((32, 64)), "float32", device)
-    b = empty_dfunc(Shape((64, 32)), "float32", device)
+    _a = empty_dfunc(Shape((32, 64)), "float32", device)
+    _b = empty_dfunc(Shape((64, 32)), "float32", device)
     sess.sync_worker_0()
     finalize_dfunc = 
sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
     finalize_dfunc()
     sess.sync_worker_0()
 
 
-def test_nvshmem_compile():
-    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is 
None:
-        return
[email protected]("session_kind", _all_session_kinds)
[email protected]("num_workers", _all_num_workers)
+def test_nvshmem_init_finalize(session_kind, num_workers: int):
+    _require_cuda_devices(num_workers)
+    _run_in_fresh_process(_init_finalize, session_kind, num_workers)
+
+
[email protected]("session_kind", _all_session_kinds)
[email protected]("num_workers", _all_num_workers)
+def test_nvshmem_empty(session_kind, num_workers: int):
+    _require_cuda_devices(num_workers)
+    _run_in_fresh_process(_empty, session_kind, num_workers)
+
 
+def _compile():
     num_workers = 2
     sess = di.ProcessSession(num_workers=num_workers)
 
@@ -194,6 +260,11 @@ def test_nvshmem_compile():
         shutil.rmtree(tmpdir, ignore_errors=True)
 
 
+def test_nvshmem_compile():
+    _require_cuda_devices(2)
+    _run_in_fresh_process(_compile)
+
+
 NVSHMEM_QUERY_KERNEL_SOURCE = """
 #include <nvshmem.h>
 
@@ -204,10 +275,12 @@ extern "C" __global__ void nvshmem_query_kernel(int* 
my_pe_out, int* n_pes_out)
 """
 
 
-def _test_nvshmem_kernel_compile_impl():
-    """Test compiling and running a kernel that calls NVSHMEM functions"""
-    if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is 
None:
-        return
+def _kernel_compile(compile_mode):
+    """Compile and run a kernel that calls NVSHMEM functions.
+
+    Runs in a fresh process, so setting the env var is safe.
+    """
+    os.environ["TVM_CUDA_COMPILE_MODE"] = compile_mode
 
     num_workers = 2
     sess = di.ProcessSession(num_workers=num_workers)
@@ -288,55 +361,20 @@ def _test_nvshmem_kernel_compile_impl():
 
 def test_nvshmem_kernel_compile_nvcc():
     """Test NVSHMEM kernel compilation with nvcc."""
-    # Since this test runs in a separate process, we can safely set the env var
-    import os
-
-    os.environ["TVM_CUDA_COMPILE_MODE"] = "nvcc"
-    _test_nvshmem_kernel_compile_impl()
+    _require_cuda_devices(2)
+    _run_in_fresh_process(_kernel_compile, "nvcc")
 
 
 def test_nvshmem_kernel_compile_nvrtc():
     """Test NVSHMEM kernel compilation with nvrtc."""
+    _require_cuda_devices(2)
     try:
-        from cuda.bindings import nvrtc
+        from cuda.bindings import nvrtc  # noqa: F401
     except ImportError:
         pytest.skip("cuda-python not available, skipping nvrtc test")
 
-    # Since this test runs in a separate process, we can safely set the env var
-    import os
-
-    os.environ["TVM_CUDA_COMPILE_MODE"] = "nvrtc"
-    _test_nvshmem_kernel_compile_impl()
+    _run_in_fresh_process(_kernel_compile, "nvrtc")
 
 
 if __name__ == "__main__":
-    # After the first call to `nvshmem_init`, a subsequent call to 
`nvshmem_init`
-    # or `nvshmem_init_thread` in the same program results in undefined 
behavior.
-    # So we always create a new process to run the test. Then no repeated 
nvshmem
-    # init happens in the same process, since the worker0 may share the same 
process.
-
-    # Use 'spawn' start method to avoid inheriting CUDA state from parent 
process
-    # 'fork' (default on Linux) can cause issues with CUDA contexts in child 
processes
-    multiprocessing.set_start_method("spawn", force=True)
-
-    for session_kind in [create_socket_session, di.ProcessSession]:
-        for num_workers in [2, 4]:
-            for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
-                p = Process(target=test_func, args=[session_kind, num_workers])
-                p.start()
-                p.join()
-                # Ensure the process finished successfully
-                assert p.exitcode == 0, (
-                    f"Test {test_func.__name__} failed with exit code 
{p.exitcode}"
-                )
-
-    p = Process(target=test_nvshmem_compile)
-    p.start()
-    p.join()
-    assert p.exitcode == 0, f"Test test_nvshmem_compile failed with exit code 
{p.exitcode}"
-
-    for test_func in [test_nvshmem_kernel_compile_nvcc, 
test_nvshmem_kernel_compile_nvrtc]:
-        p = Process(target=test_func)
-        p.start()
-        p.join()
-        assert p.exitcode == 0, f"Test {test_func.__name__} failed with exit 
code {p.exitcode}"
+    tvm.testing.main()
diff --git a/tests/python/disco/test_session.py 
b/tests/python/disco/test_session.py
index 7360ae9a6a..1f482b9ee5 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -14,10 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# ruff: noqa: F401
 """Basic tests for a Disco session"""
 
 # pylint: disable=missing-docstring
+import socket
 import subprocess
 import sys
 import tempfile
@@ -30,66 +30,64 @@ from tvm_ffi.core import String
 
 import tvm
 import tvm.testing
-from tvm import relax as rx
-from tvm.exec import disco_worker as _  # pylint: disable=unused-import
+
+# Imported for the side effect of registering the tests.disco.* worker 
functions.
+from tvm.exec import disco_worker as _  # noqa: F401  # pylint: 
disable=unused-import
 from tvm.runtime import disco as di
 from tvm.script import ir as I
 from tvm.script import relax as R
 from tvm.script import tirx as T
 
-
-def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
-    x_array = sess.empty(np_array.shape, "float32", device=device)
-    host_array = tvm.runtime.tensor(np_array, device=device)
-    sess.copy_to_worker_0(host_array, x_array)
-    return x_array
-
-
-def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype):
-    host_array = tvm.runtime.empty(shape, dtype, device=tvm.cpu())
-    sess.copy_from_worker_0(host_array, remote_array)
-    sess.sync_worker_0()
-    return host_array.numpy()
+if di is None:
+    pytest.skip("disco runtime is not available", allow_module_level=True)
 
 
 _SOCKET_SESSION_TESTER = None
 
 
-def get_free_port():
-    import socket
-
-    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    s.bind(("", 0))
-    port = s.getsockname()[1]
-    s.close()
+def _get_free_port():
+    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+    sock.bind(("", 0))
+    port = sock.getsockname()[1]
+    sock.close()
     return port
 
 
 class SocketSessionTester:
-    def __init__(self, num_workers):
-        num_nodes = 2
-        num_groups = 1
+    """Run a disco SocketSession with one local node and remote nodes.
+
+    Each remote node is a `tvm.exec.disco_remote_socket_session` subprocess
+    launched with the current Python interpreter.
+    """
+
+    def __init__(self, num_workers, num_nodes=2, num_groups=1):
+        # Initialize the attributes used by __del__ first, so that teardown is
+        # safe even when __init__ raises below.
+        self.sess = None
+        self.remote_nodes = []
         assert num_workers % num_nodes == 0
         num_workers_per_node = num_workers // num_nodes
         server_host = "localhost"
-        server_port = get_free_port()
-        self.sess = None
+        server_port = _get_free_port()
+        server_exc = []
 
         def start_server():
-            self.sess = di.SocketSession(
-                num_nodes, num_workers_per_node, num_groups, server_host, 
server_port
-            )
+            try:
+                self.sess = di.SocketSession(
+                    num_nodes, num_workers_per_node, num_groups, server_host, 
server_port
+                )
+            except Exception as exc:  # pylint: disable=broad-except
+                server_exc.append(exc)
 
         thread = threading.Thread(target=start_server)
         thread.start()
 
         cmd = "tvm.exec.disco_remote_socket_session"
-        self.remote_nodes = []
         for _i in range(num_nodes - 1):
             self.remote_nodes.append(
                 subprocess.Popen(
                     [
-                        "python3",
+                        sys.executable,
                         "-m",
                         cmd,
                         server_host,
@@ -102,24 +100,55 @@ class SocketSessionTester:
             )
 
         thread.join()
+        if server_exc:
+            raise server_exc[0]
+
+    # Bound at class creation: module globals may already be cleared when
+    # __del__ runs during interpreter shutdown.
+    _TIMEOUT_EXPIRED = subprocess.TimeoutExpired
 
     def __del__(self):
-        for node in self.remote_nodes:
-            node.kill()
-        if self.sess is not None:
-            self.sess.shutdown()
-            del self.sess
+        try:
+            # Shut down the session first so remote nodes can exit gracefully.
+            if self.sess is not None:
+                self.sess.shutdown()
+        finally:
+            for node in self.remote_nodes:
+                try:
+                    node.wait(timeout=10)
+                except self._TIMEOUT_EXPIRED:
+                    node.kill()
+                    node.wait()
 
 
 def create_socket_session(num_workers):
+    """Create a socket session backed by one local and one remote node.
+
+    The tester is kept alive in a module-level global so that the session
+    survives until the next call (or interpreter exit) replaces it.
+    """
     global _SOCKET_SESSION_TESTER
-    if _SOCKET_SESSION_TESTER is not None:
-        del _SOCKET_SESSION_TESTER
+    # Rebind (not `del`) so the global stays defined if the constructor raises.
+    _SOCKET_SESSION_TESTER = None
     _SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
     assert _SOCKET_SESSION_TESTER.sess is not None
     return _SOCKET_SESSION_TESTER.sess
 
 
+def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
+    x_array = sess.empty(np_array.shape, "float32", device=device)
+    host_array = tvm.runtime.tensor(np_array, device=device)
+    sess.copy_to_worker_0(host_array, x_array)
+    return x_array
+
+
+def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype):
+    host_array = tvm.runtime.empty(shape, dtype, device=tvm.cpu())
+    sess.copy_from_worker_0(host_array, remote_array)
+    sess.sync_worker_0()
+    return host_array.numpy()
+
+
 _all_session_kinds = [di.ThreadedSession, di.ProcessSession, 
create_socket_session]
 
 

Reply via email to