This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b57d0b3d07 [Runtime][Disco] Fix session attribute storage, NVSHMEM build, and test gating (#19736)
b57d0b3d07 is described below
commit b57d0b3d07d9e97d8e4ce16106f33c1e37940191
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 12 08:32:23 2026 -0400
[Runtime][Disco] Fix session attribute storage, NVSHMEM build, and test gating (#19736)
The tvm_ffi Object metaclass now gives every subclass `__slots__ = ()`,
so the Disco Python wrappers can no longer store instance attributes, and
every session construction fails with an AttributeError. This change
declares the attributes each wrapper actually stores as named slots, fixes
the NVSHMEM `dist_gemm.cu` so that TVM builds with `USE_NVSHMEM = ON`, and
gates the disco tests on the presence of the disco runtime, so they skip
cleanly on builds (e.g., the pip wheel) that report `USE_NCCL` /
`USE_NVSHMEM = ON` without shipping that runtime.
### Session attribute storage
- `DPackedFunc` / `DModule`: `__slots__ = ("session",)`.
- `Session`: `__slots__ = ("_cache", "_import_python_module")`.
---
python/tvm/runtime/disco/session.py | 15 +-
python/tvm/testing/utils.py | 15 ++
src/runtime/extra/contrib/nvshmem/dist_gemm.cu | 18 +--
tests/python/disco/test_loader.py | 10 ++
tests/python/disco/test_nvshmem.py | 206 +++++++++++++++----------
tests/python/disco/test_session.py | 109 ++++++++-----
6 files changed, 235 insertions(+), 138 deletions(-)
diff --git a/python/tvm/runtime/disco/session.py
b/python/tvm/runtime/disco/session.py
index 08afbbfc80..0a81250900 100644
--- a/python/tvm/runtime/disco/session.py
+++ b/python/tvm/runtime/disco/session.py
@@ -78,6 +78,10 @@ class DRef(Object):
class DPackedFunc(DRef):
"""A PackedFunc in a Disco session."""
+ # tvm_ffi Object subclasses cannot store Python attributes by default
+ # (the metaclass sets `__slots__ = ()`); list the field(s) we store here.
+ __slots__ = ("session",)
+
def __init__(self, dref: DRef, session: "Session") -> None:
self.__move_handle_from__(dref)
self.session = session
@@ -89,6 +93,10 @@ class DPackedFunc(DRef):
class DModule(DRef):
"""A Module in a Disco session."""
+ # tvm_ffi Object subclasses cannot store Python attributes by default
+ # (the metaclass sets `__slots__ = ()`); list the field(s) we store here.
+ __slots__ = ("session",)
+
def __init__(self, dref: DRef, session: "Session") -> None:
self.__move_handle_from__(dref)
self.session = session
@@ -103,8 +111,13 @@ class Session(Object):
"""A Disco interactive session. It allows users to interact with the Disco
command queue with
various PackedFunc calling convention."""
+ # tvm_ffi Object subclasses cannot store Python attributes by default
+ # (the metaclass sets `__slots__ = ()`); list the fields we store here:
+ # the method-lookup cache and the lazily bound import helper.
+ __slots__ = ("_cache", "_import_python_module")
+
def _get_cached_method(self, name: str) -> Callable:
- if "_cache" not in self.__dict__:
+ if not hasattr(self, "_cache"):
cache = self._cache = {} # pylint:
disable=attribute-defined-outside-init
else:
cache = self._cache
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 3b0dea1bb5..cfc5a357a3 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -910,6 +910,21 @@ requires_cublas = Feature("cublas", "cuBLAS",
cmake_flag="USE_CUBLAS", parent_fe
# Mark a test as requiring NCCL support
requires_nccl = Feature("nccl", "NCCL", cmake_flag="USE_NCCL",
parent_features="cuda")
+
+def _nvshmem_exists():
+ # Probe the runtime function rather than the USE_NVSHMEM cmake flag: the
+ # flag can be ON in builds that do not ship the disco NVSHMEM runtime.
+ return (
+ tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid",
allow_missing=True)
+ is not None
+ )
+
+
+# Mark a test as requiring NVSHMEM support
+requires_nvshmem = Feature(
+ "nvshmem", "NVSHMEM", run_time_check=_nvshmem_exists,
parent_features="cuda"
+)
+
# Mark a test as requiring the NVPTX compilation on the CUDA runtime
requires_nvptx = Feature(
"nvptx",
diff --git a/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
b/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
index 512613f2ed..860cba6b52 100644
--- a/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
+++ b/src/runtime/extra/contrib/nvshmem/dist_gemm.cu
@@ -37,7 +37,7 @@ void* get_pointer(Tensor data, ffi::Shape index) {
offset *= data->shape[i];
offset += index[i];
}
- return static_cast<void*>(ptr + offset * GetDataSize(1, data->dtype));
+ return static_cast<void*>(ptr + offset * ffi::GetDataSize(1, data->dtype));
}
void cuStreamWaitValue64Wrapper(TVMStreamHandle strm, void* addr, uint64_t
expected) {
@@ -60,9 +60,7 @@ void copy_to_peer(void* dst, int dst_device, void* src,
size_t size, TVMStreamHa
TVMStreamHandle stream_create() {
DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
- if (worker == nullptr) {
- LOG(FATAL) << "NVSHMEM stream creation failed: worker is not initialized";
- }
+ TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM stream creation failed: worker
is not initialized";
cudaStream_t retval;
CUDA_CALL(cudaStreamCreateWithFlags(&retval, cudaStreamNonBlocking));
return static_cast<TVMStreamHandle>(retval);
@@ -70,9 +68,7 @@ TVMStreamHandle stream_create() {
void stream_sync(TVMStreamHandle from_stream, TVMStreamHandle to_stream) {
DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
- if (worker == nullptr) {
- LOG(FATAL) << "NVSHMEM stream sync failed: worker is not initialized";
- }
+ TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM stream sync failed: worker is
not initialized";
auto f_sync_stream =
tvm::ffi::Function::GetGlobalRequired("runtime.Device_StreamSyncFromTo");
f_sync_stream(worker->default_device, reinterpret_cast<int64_t>(from_stream),
reinterpret_cast<int64_t>(to_stream));
@@ -91,9 +87,7 @@ void transfer_to_peers_reduce_scatter(Tensor semaphore,
Tensor gemm_out, Tensor
TVMStreamHandle stream, int32_t M,
int32_t N, int32_t BLK_M,
int32_t BLK_N, int32_t WORLD_SIZE) {
DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
- if (worker == nullptr) {
- LOG(FATAL) << "NVSHMEM transfer to peer failed: worker is not initialized";
- }
+ TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM transfer to peer failed:
worker is not initialized";
int my_rank = worker->worker_id;
int LOCAL_M = M / WORLD_SIZE;
for (int i = 0; i < WORLD_SIZE; i++) {
@@ -118,9 +112,7 @@ void transfer_to_peers_reduce_scatter(Tensor semaphore,
Tensor gemm_out, Tensor
void transfer_to_peers_all_gather(Tensor semaphore, Tensor A, Tensor ag_out,
TVMStreamHandle stream,
int32_t M, int32_t K, int32_t WORLD_SIZE) {
DiscoWorker* worker = ThreadLocalDiscoWorker::Get()->worker;
- if (worker == nullptr) {
- LOG(FATAL) << "NVSHMEM transfer to peer failed: worker is not initialized";
- }
+ TVM_FFI_ICHECK(worker != nullptr) << "NVSHMEM transfer to peer failed:
worker is not initialized";
int my_rank = worker->worker_id;
int LOCAL_M = M / WORLD_SIZE;
for (int i = 0; i < WORLD_SIZE; i++) {
diff --git a/tests/python/disco/test_loader.py
b/tests/python/disco/test_loader.py
index b709571219..290b9f401f 100644
--- a/tests/python/disco/test_loader.py
+++ b/tests/python/disco/test_loader.py
@@ -22,6 +22,7 @@ import json
import tempfile
import numpy as np
+import pytest
from tvm_ffi import Shape, register_global_func
import tvm
@@ -34,6 +35,15 @@ from tvm.script import ir as I
from tvm.script import relax as R
from tvm.target import Target
+# `runtime.disco.compiled_ccl` is registered together with the CCL runtime
+# functions, so its absence means the disco CCL runtime is not in this build.
+_compiled_ccl = tvm.get_global_func("runtime.disco.compiled_ccl",
allow_missing=True)
+if _compiled_ccl is None or _compiled_ccl() != "nccl":
+ pytest.skip("Disco NCCL support is not available", allow_module_level=True)
+
+# All tests in this file shard across two GPUs.
+pytestmark = tvm.testing.requires_multi_gpu.marks()
+
@register_global_func("tests.disco.shard_dim_0", override=True)
def _shard_dim_0(src, num_shards, tgt):
diff --git a/tests/python/disco/test_nvshmem.py
b/tests/python/disco/test_nvshmem.py
index 77b57a2c0b..5d70ccf6bd 100644
--- a/tests/python/disco/test_nvshmem.py
+++ b/tests/python/disco/test_nvshmem.py
@@ -14,19 +14,17 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: F401, F841
"""Basic tests for a Disco nvshmem support"""
# pylint: disable=missing-docstring
import multiprocessing
+import os
import shutil
+import socket
import subprocess
import sys
import tempfile
import threading
-from collections.abc import Callable
-from multiprocessing import Process
-from typing import Any
import numpy as np
import pytest
@@ -34,50 +32,63 @@ from tvm_ffi import Shape
import tvm
import tvm.testing
-from tvm.exec import disco_worker as _ # pylint: disable=unused-import
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tirx as T
-_SOCKET_SESSION_TESTER = None
+if di is None:
+ pytest.skip("disco runtime is not available", allow_module_level=True)
+
+pytestmark = tvm.testing.requires_nvshmem.marks()
-def get_free_port():
- import socket
+_SOCKET_SESSION_TESTER = None
- s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- s.bind(("", 0))
- port = s.getsockname()[1]
- s.close()
+
+def _get_free_port():
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(("", 0))
+ port = sock.getsockname()[1]
+ sock.close()
return port
class SocketSessionTester:
- def __init__(self, num_workers):
- num_nodes = 2
- num_groups = 1
+ """Run a disco SocketSession with one local node and remote nodes.
+
+ Each remote node is a `tvm.exec.disco_remote_socket_session` subprocess
+ launched with the current Python interpreter.
+ """
+
+ def __init__(self, num_workers, num_nodes=2, num_groups=1):
+ # Initialize the attributes used by __del__ first, so that teardown is
+ # safe even when __init__ raises below.
+ self.sess = None
+ self.remote_nodes = []
assert num_workers % num_nodes == 0
num_workers_per_node = num_workers // num_nodes
server_host = "localhost"
- server_port = get_free_port()
- self.sess = None
+ server_port = _get_free_port()
+ server_exc = []
def start_server():
- self.sess = di.SocketSession(
- num_nodes, num_workers_per_node, num_groups, server_host,
server_port
- )
+ try:
+ self.sess = di.SocketSession(
+ num_nodes, num_workers_per_node, num_groups, server_host,
server_port
+ )
+ except Exception as exc: # pylint: disable=broad-except
+ server_exc.append(exc)
thread = threading.Thread(target=start_server)
thread.start()
cmd = "tvm.exec.disco_remote_socket_session"
- self.remote_nodes = []
for _i in range(num_nodes - 1):
self.remote_nodes.append(
subprocess.Popen(
[
- "python3",
+ sys.executable,
"-m",
cmd,
server_host,
@@ -90,26 +101,73 @@ class SocketSessionTester:
)
thread.join()
+ if server_exc:
+ raise server_exc[0]
+
+ # Bound at class creation: module globals may already be cleared when
+ # __del__ runs during interpreter shutdown.
+ _TIMEOUT_EXPIRED = subprocess.TimeoutExpired
def __del__(self):
- if self.sess is not None:
- self.sess.shutdown()
- del self.sess
+ try:
+ # Shut down the session first so remote nodes can exit gracefully.
+ if self.sess is not None:
+ self.sess.shutdown()
+ finally:
+ for node in self.remote_nodes:
+ try:
+ node.wait(timeout=10)
+ except self._TIMEOUT_EXPIRED:
+ node.kill()
+ node.wait()
def create_socket_session(num_workers):
+ """Create a socket session backed by one local and one remote node.
+
+ The tester is kept alive in a module-level global so that the session
+ survives until the next call (or interpreter exit) replaces it.
+ """
global _SOCKET_SESSION_TESTER
- if _SOCKET_SESSION_TESTER is not None:
- del _SOCKET_SESSION_TESTER
+ # Rebind (not `del`) so the global stays defined if the constructor raises.
+ _SOCKET_SESSION_TESTER = None
_SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
assert _SOCKET_SESSION_TESTER.sess is not None
return _SOCKET_SESSION_TESTER.sess
-def test_nvshmem_init_finalize(session_kind: di.Session, num_workers: int):
- if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is
None:
- return
+_all_session_kinds = [di.ProcessSession, create_socket_session]
+_all_num_workers = [2, 4]
+
+_SUBPROCESS_TIMEOUT_SEC = 600
+
+
+def _run_in_fresh_process(target, *args):
+ """Run a test body in a freshly spawned process.
+ After the first call to `nvshmem_init`, a subsequent call to `nvshmem_init`
+ or `nvshmem_init_thread` in the same program results in undefined behavior,
+ and worker-0 of a Disco session lives in the calling process. So each test
+ body must run in its own process. The 'spawn' start method avoids
+ inheriting CUDA state from this process.
+ """
+ proc = multiprocessing.get_context("spawn").Process(target=target,
args=args)
+ proc.start()
+ proc.join(timeout=_SUBPROCESS_TIMEOUT_SEC)
+ if proc.is_alive():
+ proc.kill()
+ proc.join()
+ pytest.fail(f"{target.__name__}{args} timed out after
{_SUBPROCESS_TIMEOUT_SEC} seconds")
+ assert proc.exitcode == 0, f"{target.__name__}{args} failed with exit code
{proc.exitcode}"
+
+
+def _require_cuda_devices(num_workers):
+ # Each nvshmem worker binds its own CUDA device (cudaSetDevice(worker_id)).
+ if not all(tvm.cuda(i).exist for i in range(num_workers)):
+ pytest.skip(f"Requires {num_workers} CUDA devices")
+
+
+def _init_finalize(session_kind, num_workers):
sess = session_kind(num_workers=num_workers)
f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
uid = f_init_nvshmem_uid()
@@ -121,10 +179,7 @@ def test_nvshmem_init_finalize(session_kind: di.Session,
num_workers: int):
sess.sync_worker_0()
-def test_nvshmem_empty(session_kind: di.Session, num_workers: int):
- if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is
None:
- return
-
+def _empty(session_kind, num_workers):
device = tvm.cuda()
sess = session_kind(num_workers=num_workers)
f_init_nvshmem_uid =
tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid")
@@ -133,18 +188,29 @@ def test_nvshmem_empty(session_kind: di.Session,
num_workers: int):
init_dfunc(uid, num_workers, 0)
sess.sync_worker_0()
empty_dfunc = sess.get_global_func("runtime.disco.nvshmem.empty")
- a = empty_dfunc(Shape((32, 64)), "float32", device)
- b = empty_dfunc(Shape((64, 32)), "float32", device)
+ _a = empty_dfunc(Shape((32, 64)), "float32", device)
+ _b = empty_dfunc(Shape((64, 32)), "float32", device)
sess.sync_worker_0()
finalize_dfunc =
sess.get_global_func("runtime.disco.nvshmem.finalize_nvshmem")
finalize_dfunc()
sess.sync_worker_0()
-def test_nvshmem_compile():
- if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is
None:
- return
[email protected]("session_kind", _all_session_kinds)
[email protected]("num_workers", _all_num_workers)
+def test_nvshmem_init_finalize(session_kind, num_workers: int):
+ _require_cuda_devices(num_workers)
+ _run_in_fresh_process(_init_finalize, session_kind, num_workers)
+
+
[email protected]("session_kind", _all_session_kinds)
[email protected]("num_workers", _all_num_workers)
+def test_nvshmem_empty(session_kind, num_workers: int):
+ _require_cuda_devices(num_workers)
+ _run_in_fresh_process(_empty, session_kind, num_workers)
+
+def _compile():
num_workers = 2
sess = di.ProcessSession(num_workers=num_workers)
@@ -194,6 +260,11 @@ def test_nvshmem_compile():
shutil.rmtree(tmpdir, ignore_errors=True)
+def test_nvshmem_compile():
+ _require_cuda_devices(2)
+ _run_in_fresh_process(_compile)
+
+
NVSHMEM_QUERY_KERNEL_SOURCE = """
#include <nvshmem.h>
@@ -204,10 +275,12 @@ extern "C" __global__ void nvshmem_query_kernel(int*
my_pe_out, int* n_pes_out)
"""
-def _test_nvshmem_kernel_compile_impl():
- """Test compiling and running a kernel that calls NVSHMEM functions"""
- if tvm.get_global_func("runtime.disco.nvshmem.init_nvshmem_uid", True) is
None:
- return
+def _kernel_compile(compile_mode):
+ """Compile and run a kernel that calls NVSHMEM functions.
+
+ Runs in a fresh process, so setting the env var is safe.
+ """
+ os.environ["TVM_CUDA_COMPILE_MODE"] = compile_mode
num_workers = 2
sess = di.ProcessSession(num_workers=num_workers)
@@ -288,55 +361,20 @@ def _test_nvshmem_kernel_compile_impl():
def test_nvshmem_kernel_compile_nvcc():
"""Test NVSHMEM kernel compilation with nvcc."""
- # Since this test runs in a separate process, we can safely set the env var
- import os
-
- os.environ["TVM_CUDA_COMPILE_MODE"] = "nvcc"
- _test_nvshmem_kernel_compile_impl()
+ _require_cuda_devices(2)
+ _run_in_fresh_process(_kernel_compile, "nvcc")
def test_nvshmem_kernel_compile_nvrtc():
"""Test NVSHMEM kernel compilation with nvrtc."""
+ _require_cuda_devices(2)
try:
- from cuda.bindings import nvrtc
+ from cuda.bindings import nvrtc # noqa: F401
except ImportError:
pytest.skip("cuda-python not available, skipping nvrtc test")
- # Since this test runs in a separate process, we can safely set the env var
- import os
-
- os.environ["TVM_CUDA_COMPILE_MODE"] = "nvrtc"
- _test_nvshmem_kernel_compile_impl()
+ _run_in_fresh_process(_kernel_compile, "nvrtc")
if __name__ == "__main__":
- # After the first call to `nvshmem_init`, a subsequent call to
`nvshmem_init`
- # or `nvshmem_init_thread` in the same program results in undefined
behavior.
- # So we always create a new process to run the test. Then no repeated
nvshmem
- # init happens in the same process, since the worker0 may share the same
process.
-
- # Use 'spawn' start method to avoid inheriting CUDA state from parent
process
- # 'fork' (default on Linux) can cause issues with CUDA contexts in child
processes
- multiprocessing.set_start_method("spawn", force=True)
-
- for session_kind in [create_socket_session, di.ProcessSession]:
- for num_workers in [2, 4]:
- for test_func in [test_nvshmem_init_finalize, test_nvshmem_empty]:
- p = Process(target=test_func, args=[session_kind, num_workers])
- p.start()
- p.join()
- # Ensure the process finished successfully
- assert p.exitcode == 0, (
- f"Test {test_func.__name__} failed with exit code
{p.exitcode}"
- )
-
- p = Process(target=test_nvshmem_compile)
- p.start()
- p.join()
- assert p.exitcode == 0, f"Test test_nvshmem_compile failed with exit code
{p.exitcode}"
-
- for test_func in [test_nvshmem_kernel_compile_nvcc,
test_nvshmem_kernel_compile_nvrtc]:
- p = Process(target=test_func)
- p.start()
- p.join()
- assert p.exitcode == 0, f"Test {test_func.__name__} failed with exit
code {p.exitcode}"
+ tvm.testing.main()
diff --git a/tests/python/disco/test_session.py
b/tests/python/disco/test_session.py
index 7360ae9a6a..1f482b9ee5 100644
--- a/tests/python/disco/test_session.py
+++ b/tests/python/disco/test_session.py
@@ -14,10 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# ruff: noqa: F401
"""Basic tests for a Disco session"""
# pylint: disable=missing-docstring
+import socket
import subprocess
import sys
import tempfile
@@ -30,66 +30,64 @@ from tvm_ffi.core import String
import tvm
import tvm.testing
-from tvm import relax as rx
-from tvm.exec import disco_worker as _ # pylint: disable=unused-import
+
+# Imported for the side effect of registering the tests.disco.* worker
functions.
+from tvm.exec import disco_worker as _ # noqa: F401 # pylint:
disable=unused-import
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tirx as T
-
-def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
- x_array = sess.empty(np_array.shape, "float32", device=device)
- host_array = tvm.runtime.tensor(np_array, device=device)
- sess.copy_to_worker_0(host_array, x_array)
- return x_array
-
-
-def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype):
- host_array = tvm.runtime.empty(shape, dtype, device=tvm.cpu())
- sess.copy_from_worker_0(host_array, remote_array)
- sess.sync_worker_0()
- return host_array.numpy()
+if di is None:
+ pytest.skip("disco runtime is not available", allow_module_level=True)
_SOCKET_SESSION_TESTER = None
-def get_free_port():
- import socket
-
- s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- s.bind(("", 0))
- port = s.getsockname()[1]
- s.close()
+def _get_free_port():
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.bind(("", 0))
+ port = sock.getsockname()[1]
+ sock.close()
return port
class SocketSessionTester:
- def __init__(self, num_workers):
- num_nodes = 2
- num_groups = 1
+ """Run a disco SocketSession with one local node and remote nodes.
+
+ Each remote node is a `tvm.exec.disco_remote_socket_session` subprocess
+ launched with the current Python interpreter.
+ """
+
+ def __init__(self, num_workers, num_nodes=2, num_groups=1):
+ # Initialize the attributes used by __del__ first, so that teardown is
+ # safe even when __init__ raises below.
+ self.sess = None
+ self.remote_nodes = []
assert num_workers % num_nodes == 0
num_workers_per_node = num_workers // num_nodes
server_host = "localhost"
- server_port = get_free_port()
- self.sess = None
+ server_port = _get_free_port()
+ server_exc = []
def start_server():
- self.sess = di.SocketSession(
- num_nodes, num_workers_per_node, num_groups, server_host,
server_port
- )
+ try:
+ self.sess = di.SocketSession(
+ num_nodes, num_workers_per_node, num_groups, server_host,
server_port
+ )
+ except Exception as exc: # pylint: disable=broad-except
+ server_exc.append(exc)
thread = threading.Thread(target=start_server)
thread.start()
cmd = "tvm.exec.disco_remote_socket_session"
- self.remote_nodes = []
for _i in range(num_nodes - 1):
self.remote_nodes.append(
subprocess.Popen(
[
- "python3",
+ sys.executable,
"-m",
cmd,
server_host,
@@ -102,24 +100,55 @@ class SocketSessionTester:
)
thread.join()
+ if server_exc:
+ raise server_exc[0]
+
+ # Bound at class creation: module globals may already be cleared when
+ # __del__ runs during interpreter shutdown.
+ _TIMEOUT_EXPIRED = subprocess.TimeoutExpired
def __del__(self):
- for node in self.remote_nodes:
- node.kill()
- if self.sess is not None:
- self.sess.shutdown()
- del self.sess
+ try:
+ # Shut down the session first so remote nodes can exit gracefully.
+ if self.sess is not None:
+ self.sess.shutdown()
+ finally:
+ for node in self.remote_nodes:
+ try:
+ node.wait(timeout=10)
+ except self._TIMEOUT_EXPIRED:
+ node.kill()
+ node.wait()
def create_socket_session(num_workers):
+ """Create a socket session backed by one local and one remote node.
+
+ The tester is kept alive in a module-level global so that the session
+ survives until the next call (or interpreter exit) replaces it.
+ """
global _SOCKET_SESSION_TESTER
- if _SOCKET_SESSION_TESTER is not None:
- del _SOCKET_SESSION_TESTER
+ # Rebind (not `del`) so the global stays defined if the constructor raises.
+ _SOCKET_SESSION_TESTER = None
_SOCKET_SESSION_TESTER = SocketSessionTester(num_workers)
assert _SOCKET_SESSION_TESTER.sess is not None
return _SOCKET_SESSION_TESTER.sess
+def _numpy_to_worker_0(sess: di.Session, np_array: np.array, device):
+ x_array = sess.empty(np_array.shape, "float32", device=device)
+ host_array = tvm.runtime.tensor(np_array, device=device)
+ sess.copy_to_worker_0(host_array, x_array)
+ return x_array
+
+
+def _numpy_from_worker_0(sess: di.Session, remote_array, shape, dtype):
+ host_array = tvm.runtime.empty(shape, dtype, device=tvm.cpu())
+ sess.copy_from_worker_0(host_array, remote_array)
+ sess.sync_worker_0()
+ return host_array.numpy()
+
+
_all_session_kinds = [di.ThreadedSession, di.ProcessSession,
create_socket_session]