This is an automated email from the ASF dual-hosted git repository. zhaowu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push: new afcf939 Windows Support for cpp_rpc (#4857) afcf939 is described below commit afcf9397b60ae7ccf46601cf29828992ca9d5f57 Author: jmorrill <jeremiah.morr...@gmail.com> AuthorDate: Wed Apr 15 01:49:15 2020 -0700 Windows Support for cpp_rpc (#4857) * Windows Support for cpp_rpc * Add missing patches that fix crashes under Windows * On Windows, use python to untar vs wsl * remove some CMakeLists.txt stuff * more minor CMakeLists.txt changes * Remove items from CMakeLists.txt * Minor CMakeLists.txt changes * More minor CMakeLists.txt changes * Even more minor CMakeLists.txt changes * Modify readme --- CMakeLists.txt | 8 + apps/cpp_rpc/CMakeLists.txt | 27 ++++ apps/cpp_rpc/README.md | 10 +- apps/cpp_rpc/main.cc | 95 ++++++++---- apps/cpp_rpc/rpc_env.cc | 305 ++++++++++++++++++++----------------- apps/cpp_rpc/rpc_env.h | 6 +- apps/cpp_rpc/rpc_server.cc | 250 +++++++++++++++--------------- apps/cpp_rpc/rpc_server.h | 23 ++- apps/cpp_rpc/win32_process.cc | 273 +++++++++++++++++++++++++++++++++ apps/cpp_rpc/win32_process.h | 43 ++++++ src/runtime/rpc/rpc_socket_impl.cc | 11 +- src/support/ring_buffer.h | 2 +- 12 files changed, 751 insertions(+), 302 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cf334ff..8a559b8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -66,9 +66,14 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF) tvm_option(USE_RANDOM "Build with random support" OFF) tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF) tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF) +tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) +if(USE_CPP_RPC AND UNIX) + message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. 
Use the Makefile for non-Windows.") +endif() + # include directories include_directories(${CMAKE_INCLUDE_PATH}) include_directories("include") @@ -309,6 +314,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS}) add_library(tvm_topi SHARED ${TOPI_SRCS}) add_library(tvm_runtime SHARED ${RUNTIME_SRCS}) +if(USE_CPP_RPC) + add_subdirectory("apps/cpp_rpc") +endif() if(USE_RELAY_DEBUG) message(STATUS "Building Relay in debug mode...") diff --git a/apps/cpp_rpc/CMakeLists.txt b/apps/cpp_rpc/CMakeLists.txt new file mode 100644 index 0000000..9888738 --- /dev/null +++ b/apps/cpp_rpc/CMakeLists.txt @@ -0,0 +1,27 @@ +set(TVM_RPC_SOURCES + main.cc + rpc_env.cc + rpc_server.cc +) + +if(WIN32) + list(APPEND TVM_RPC_SOURCES win32_process.cc) +endif() + +# Set output to same directory as the other TVM libs +set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) +add_executable(tvm_rpc ${TVM_RPC_SOURCES}) +set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE) + +if(WIN32) + target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX) +endif() + +target_include_directories( + tvm_rpc + PUBLIC "../../include" + PUBLIC DLPACK_PATH + PUBLIC DMLC_PATH +) + +target_link_libraries(tvm_rpc tvm_runtime) \ No newline at end of file diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md index 4baecaf..c826dae 100644 --- a/apps/cpp_rpc/README.md +++ b/apps/cpp_rpc/README.md @@ -18,7 +18,7 @@ # TVM RPC Server This folder contains a simple recipe to make RPC server in c++. -## Usage +## Usage (Non-Windows) - Build tvm runtime - Make the rpc executable [Makefile](Makefile). `make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux` @@ -35,6 +35,12 @@ This folder contains a simple recipe to make RPC server in c++. 
``` - Use `./tvm_rpc server` to start the RPC server +## Usage (Windows) +- Build tvm with the argument -DUSE_CPP_RPC=ON +- Install [LLVM pre-built binaries](https://releases.llvm.org/download.html), making sure to select the option to add it to the PATH. +- Verify Python 3.6 or newer is installed and in the PATH. +- Use `<tvm_output_dir>\tvm_rpc.exe` to start the RPC server + ## How it works - The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library. @@ -53,4 +59,4 @@ Command line usage ``` ## Note -Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently. \ No newline at end of file +Currently only Linux / Android / Windows environments are supported, and proxy mode is not supported yet. \ No newline at end of file diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc index ae66bd2..5168da3 100644 --- a/apps/cpp_rpc/main.cc +++ b/apps/cpp_rpc/main.cc @@ -21,10 +21,12 @@ * \file rpc_server.cc * \brief RPC Server for TVM. */ -#include <stdlib.h> -#include <signal.h> -#include <stdio.h> +#include <cstdlib> +#include <csignal> +#include <cstdio> +#if defined(__linux__) || defined(__ANDROID__) #include <unistd.h> +#endif #include <dmlc/logging.h> #include <iostream> #include <cstring> @@ -35,11 +37,15 @@ #include "../../src/support/socket.h" #include "rpc_server.h" +#if defined(_WIN32) +#include "win32_process.h" +#endif + using namespace std; using namespace tvm::runtime; using namespace tvm::support; -static const string kUSAGE = \ +static const string kUsage = \ "Command line usage\n" \ " server - Start the server\n" \ "--host - The hostname of the server, Default=0.0.0.0\n" \ @@ -73,13 +79,16 @@ struct RpcServerArgs { string key; string custom_addr; bool silent = false; +#if defined(WIN32) + std::string mmap_path; +#endif }; /*!
* \brief PrintArgs print the contents of RpcServerArgs * \param args RpcServerArgs structure */ -void PrintArgs(struct RpcServerArgs args) { +void PrintArgs(const RpcServerArgs& args) { LOG(INFO) << "host = " << args.host; LOG(INFO) << "port = " << args.port; LOG(INFO) << "port_end = " << args.port_end; @@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) { LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False")); } +#if defined(__linux__) || defined(__ANDROID__) /*! * \brief CtrlCHandler, exits if Ctrl+C is pressed * \param s signal @@ -109,7 +119,7 @@ void HandleCtrlC() { sigIntHandler.sa_flags = 0; sigaction(SIGINT, &sigIntHandler, nullptr); } - +#endif /*! * \brief GetCmdOption Parse and find the command option. * \param argc arg counter @@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) { } // We assume "=" is the end of option. CHECK_EQ(*option.rbegin(), '='); - cmd = arg.substr(arg.find("=") + 1); + cmd = arg.substr(arg.find('=') + 1); return cmd; } } @@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) { * \brief ParseCmdArgs parses the command line arguments. 
* \param argc arg counter * \param argv arg values - * \param args, the output structure which holds the parsed values + * \param args the output structure which holds the parsed values */ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { - string silent = GetCmdOption(argc, argv, "--silent", true); + const string silent = GetCmdOption(argc, argv, "--silent", true); if (!silent.empty()) { args.silent = true; // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } - string host = GetCmdOption(argc, argv, "--host="); + const string host = GetCmdOption(argc, argv, "--host="); if (!host.empty()) { if (!ValidateIP(host)) { LOG(WARNING) << "Wrong host address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.host = host; } - string port = GetCmdOption(argc, argv, "--port="); + const string port = GetCmdOption(argc, argv, "--port="); if (!port.empty()) { if (!IsNumber(port) || stoi(port) > 65535) { LOG(WARNING) << "Wrong port number."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.port = stoi(port); } - string port_end = GetCmdOption(argc, argv, "--port_end="); + const string port_end = GetCmdOption(argc, argv, "--port_end="); if (!port_end.empty()) { if (!IsNumber(port_end) || stoi(port_end) > 65535) { LOG(WARNING) << "Wrong port_end number."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.port_end = stoi(port_end); @@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { if (!tracker.empty()) { if (!ValidateTracker(tracker)) { LOG(WARNING) << "Wrong tracker address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.tracker = tracker; } - string key = GetCmdOption(argc, argv, "--key="); + const string key = GetCmdOption(argc, argv, "--key="); if (!key.empty()) { args.key = key; } - string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); + const string custom_addr = GetCmdOption(argc, argv, "--custom_addr="); if 
(!custom_addr.empty()) { if (!ValidateIP(custom_addr)) { LOG(WARNING) << "Wrong custom address format."; - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; exit(1); } args.custom_addr = custom_addr; } +#if defined(WIN32) + const string mmap_path = GetCmdOption(argc, argv, "--child_proc="); + if(!mmap_path.empty()) { + args.mmap_path = mmap_path; + dmlc::InitLogging("--minloglevel=0"); + } +#endif + } /*! @@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) { * \return result of operation. */ int RpcServer(int argc, char * argv[]) { - struct RpcServerArgs args; + RpcServerArgs args; /* parse the command line args */ ParseCmdArgs(argc, argv, args); PrintArgs(args); - // Ctrl+C handler LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop."; +#if defined(__linux__) || defined(__ANDROID__) + // Ctrl+C handler HandleCtrlC(); - tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker, - args.key, args.custom_addr, args.silent); +#endif + +#if defined(WIN32) + if(!args.mmap_path.empty()) { + int ret = 0; + + try { + ChildProcSocketHandler(args.mmap_path); + } catch (const std::exception&) { + ret = -1; + } + + return ret; + } +#endif + + RPCServerCreate(args.host, args.port, args.port_end, args.tracker, + args.key, args.custom_addr, args.silent); return 0; } @@ -251,15 +286,21 @@ int RpcServer(int argc, char * argv[]) { */ int main(int argc, char * argv[]) { if (argc <= 1) { - LOG(INFO) << kUSAGE; + LOG(INFO) << kUsage; return 0; } + // Runs WSAStartup on Win32, no-op on POSIX + Socket::Startup(); +#if defined(_WIN32) + SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1"); +#endif + if (0 == strcmp(argv[1], "server")) { - RpcServer(argc, argv); - } else { - LOG(INFO) << kUSAGE; + return RpcServer(argc, argv); } + LOG(INFO) << kUsage; + return 0; } diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc index 844a7af..b5dc51b 100644 --- a/apps/cpp_rpc/rpc_env.cc +++ b/apps/cpp_rpc/rpc_env.cc @@ -20,77 +20,86 
@@ * \file rpc_env.cc * \brief Server environment of the RPC. */ +#include <cerrno> #include <tvm/runtime/registry.h> -#include <errno.h> -#ifndef _MSC_VER -#include <sys/stat.h> +#ifndef _WIN32 #include <dirent.h> +#include <sys/stat.h> #include <unistd.h> #else #include <Windows.h> +#include <direct.h> +namespace { + int mkdir(const char* path, int /* ignored */) { return _mkdir(path); } +} #endif +#include <cstring> #include <fstream> -#include <vector> #include <iostream> #include <string> -#include <cstring> +#include <vector> +#include <string> -#include "rpc_env.h" #include "../../src/support/util.h" #include "../../src/runtime/file_util.h" +#include "rpc_env.h" + +namespace { + std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) { + std::string untar_cmd; + untar_cmd.reserve(512); +#if defined(__linux__) || defined(__ANDROID__) + untar_cmd += "tar -C "; + untar_cmd += output_dir; + untar_cmd += " -zxf "; + untar_cmd += tar_file; +#elif defined(_WIN32) + untar_cmd += "python -m tarfile -e "; + untar_cmd += tar_file; + untar_cmd += " "; + untar_cmd += output_dir; +#endif + return untar_cmd; + } + +}// Anonymous namespace namespace tvm { namespace runtime { - RPCEnv::RPCEnv() { - #if defined(__linux__) || defined(__ANDROID__) - base_ = "./rpc"; - mkdir(&base_[0], 0777); - - TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath") - .set_body([](TVMArgs args, TVMRetValue* rv) { - static RPCEnv env; - *rv = env.GetPath(args[0]); - }); + base_ = "./rpc"; + mkdir(base_.c_str(), 0777); + TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + *rv = env.GetPath(args[0]); + }); - TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module") - .set_body([](TVMArgs args, TVMRetValue *rv) { - static RPCEnv env; - std::string file_name = env.GetPath(args[0]); - *rv = Load(&file_name, ""); - LOG(INFO) << "Load module from " << file_name << " ..."; - }); - #else - LOG(FATAL) << "Only support RPC 
in linux environment"; - #endif + TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) { + static RPCEnv env; + std::string file_name = env.GetPath(args[0]); + *rv = Load(&file_name, ""); + LOG(INFO) << "Load module from " << file_name << " ..."; + }); } /*! - * \brief GetPath To get the workpath from packed function - * \param name The file name + * \brief GetPath To get the work path from packed function + * \param file_name The file name * \return The full path of file. */ -std::string RPCEnv::GetPath(std::string file_name) { +std::string RPCEnv::GetPath(const std::string& file_name) const { // we assume file_name has "/" means file_name is the exact path // and does not create /.rpc/ - if (file_name.find("/") != std::string::npos) { - return file_name; - } else { - return base_ + "/" + file_name; - } + return file_name.find('/') != std::string::npos ? file_name : base_ + "/" + file_name; } /*! * \brief Remove The RPC Environment cleanup function */ -void RPCEnv::CleanUp() { - #if defined(__linux__) || defined(__ANDROID__) - CleanDir(&base_[0]); - int ret = rmdir(&base_[0]); - if (ret != 0) { - LOG(WARNING) << "Remove directory " << base_ << " failed"; - } - #else - LOG(FATAL) << "Only support RPC in linux environment"; - #endif +void RPCEnv::CleanUp() const { + CleanDir(base_); + const int ret = rmdir(base_.c_str()); + if (ret != 0) { + LOG(WARNING) << "Remove directory " << base_ << " failed"; + } } /*! @@ -98,53 +107,54 @@ void RPCEnv::CleanUp() { * \param dirname The root directory name * \return vector Files in directory. 
*/ -std::vector<std::string> ListDir(const std::string &dirname) { +std::vector<std::string> ListDir(const std::string& dirname) { std::vector<std::string> vec; - #ifndef _MSC_VER - DIR *dp = opendir(dirname.c_str()); - if (dp == nullptr) { - int errsv = errno; - LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv); - } - dirent *d; - while ((d = readdir(dp)) != nullptr) { - std::string filename = d->d_name; - if (filename != "." && filename != "..") { - std::string f = dirname; - if (f[f.length() - 1] != '/') { - f += '/'; - } - f += d->d_name; - vec.push_back(f); +#ifndef _WIN32 + DIR* dp = opendir(dirname.c_str()); + if (dp == nullptr) { + int errsv = errno; + LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + } + dirent* d; + while ((d = readdir(dp)) != nullptr) { + std::string filename = d->d_name; + if (filename != "." && filename != "..") { + std::string f = dirname; + if (f[f.length() - 1] != '/') { + f += '/'; } + f += d->d_name; + vec.push_back(f); } - closedir(dp); - #else - WIN32_FIND_DATA fd; - std::string pattern = dirname + "/*"; - HANDLE handle = FindFirstFile(pattern.c_str(), &fd); - if (handle == INVALID_HANDLE_VALUE) { - int errsv = GetLastError(); - LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); - } - do { - if (fd.cFileName != "." && fd.cFileName != "..") { - std::string f = dirname; - char clast = f[f.length() - 1]; - if (f == ".") { - f = fd.cFileName; - } else if (clast != '/' && clast != '\\') { - f += '/'; - f += fd.cFileName; - } - vec.push_back(f); + } + closedir(dp); +#elif defined(_WIN32) + WIN32_FIND_DATAA fd; + const std::string pattern = dirname + "/*"; + HANDLE handle = FindFirstFileA(pattern.c_str(), &fd); + if (handle == INVALID_HANDLE_VALUE) { + const int errsv = GetLastError(); + LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv); + } + do { + std::string filename = fd.cFileName; + if (filename != "." 
&& filename != "..") { + std::string f = dirname; + if (f[f.length() - 1] != '/') { + f += '/'; } - } while (FindNextFile(handle, &fd)); - FindClose(handle); - #endif + f += filename; + vec.push_back(f); + } + } while (FindNextFileA(handle, &fd)); + FindClose(handle); +#else + LOG(FATAL) << "Operating system not supported"; +#endif return vec; } +#if defined(__linux__) || defined(__ANDROID__) /*! * \brief LinuxShared Creates a linux shared library * \param output The output file name @@ -152,9 +162,9 @@ std::vector<std::string> ListDir(const std::string &dirname) { * \param options The compiler options * \param cc The compiler */ -void LinuxShared(const std::string output, +void LinuxShared(const std::string output, const std::vector<std::string> &files, - std::string options = "", + std::string options = "", std::string cc = "g++") { std::string cmd = cc; cmd += " -shared -fPIC "; @@ -169,18 +179,48 @@ void LinuxShared(const std::string output, LOG(FATAL) << err_msg; } } +#endif + +#ifdef _WIN32 +/*! + * \brief WindowsShared Creates a Windows shared library + * \param output The output file name + * \param files The files for building + * \param options The compiler options + * \param cc The compiler + */ +void WindowsShared(const std::string& output, + const std::vector<std::string>& files, + const std::string& options = "", + const std::string& cc = "clang") { + std::string cmd = cc; + cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared "; + cmd += " -o " + output; + for (const auto& file : files) { + cmd += " " + file; + } + cmd += " " + options; + std::string err_msg; + const auto executed_status = support::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; + } +} +#endif /*! 
* \brief CreateShared Creates a shared library * \param output The output file name * \param files The files for building */ -void CreateShared(const std::string output, const std::vector<std::string> &files) { - #if defined(__linux__) || defined(__ANDROID__) - LinuxShared(output, files); - #else - LOG(FATAL) << "Do not support creating shared library"; - #endif +void CreateShared(const std::string& output, const std::vector<std::string>& files) { +#if defined(__linux__) || defined(__ANDROID__) + LinuxShared(output, files); +#elif defined(_WIN32) + WindowsShared(output, files); +#else + LOG(FATAL) << "Operating system not supported"; +#endif } /*! @@ -193,61 +233,52 @@ void CreateShared(const std::string output, const std::vector<std::string> &file * \param fmt The format of file * \return Module The loaded module */ -Module Load(std::string *fileIn, const std::string fmt) { - std::string file = *fileIn; - if (support::EndsWith(file, ".so")) { - return Module::LoadFromFile(file, fmt); +Module Load(std::string *fileIn, const std::string& fmt) { + const std::string& file = *fileIn; + if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) { + return Module::LoadFromFile(file, fmt); } - #if defined(__linux__) || defined(__ANDROID__) - std::string file_name = file + ".so"; - if (support::EndsWith(file, ".o")) { - std::vector<std::string> files; - files.push_back(file); - CreateShared(file_name, files); - } else if (support::EndsWith(file, ".tar")) { - std::string tmp_dir = "./rpc/tmp/"; - mkdir(&tmp_dir[0], 0777); - std::string cmd = "tar -C " + tmp_dir + " -zxf " + file; - std::string err_msg; - int executed_status = support::Execute(cmd, &err_msg); - if (executed_status) { - LOG(FATAL) << err_msg; - } - CreateShared(file_name, ListDir(tmp_dir)); - CleanDir(tmp_dir); - rmdir(&tmp_dir[0]); - } else { - file_name = file; + std::string file_name = file + ".so"; + if (support::EndsWith(file, ".o")) { + std::vector<std::string> files; + 
files.push_back(file); + CreateShared(file_name, files); + } else if (support::EndsWith(file, ".tar")) { + const std::string tmp_dir = "./rpc/tmp/"; + mkdir(tmp_dir.c_str(), 0777); + + const std::string cmd = GenerateUntarCommand(file, tmp_dir); + + std::string err_msg; + const int executed_status = support::Execute(cmd, &err_msg); + if (executed_status) { + LOG(FATAL) << err_msg; } - *fileIn = file_name; - return Module::LoadFromFile(file_name, fmt); - #else - LOG(FATAL) << "Do not support creating shared library"; - #endif + CreateShared(file_name, ListDir(tmp_dir)); + CleanDir(tmp_dir); + (void)rmdir(tmp_dir.c_str()); + } else { + file_name = file; + } + *fileIn = file_name; + return Module::LoadFromFile(file_name, fmt); } /*! * \brief CleanDir Removes the files from the directory * \param dirname The name of the directory */ -void CleanDir(const std::string &dirname) { - #if defined(__linux__) || defined(__ANDROID__) - DIR *dp = opendir(dirname.c_str()); - dirent *d; - while ((d = readdir(dp)) != nullptr) { - std::string filename = d->d_name; - if (filename != "." 
&& filename != "..") { - filename = dirname + "/" + d->d_name; - int ret = std::remove(&filename[0]); - if (ret != 0) { - LOG(WARNING) << "Remove file " << filename << " failed"; - } - } +void CleanDir(const std::string& dirname) { + auto files = ListDir(dirname); + for (const auto& filename : files) { + std::string file_path = dirname + "/"; + file_path += filename; + const int ret = std::remove(filename.c_str()); + if (ret != 0) { + LOG(WARNING) << "Remove file " << filename << " failed"; } - #else - LOG(FATAL) << "Only support RPC in linux environment"; - #endif + } } } // namespace runtime diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h index 82409ba..d046f6e 100644 --- a/apps/cpp_rpc/rpc_env.h +++ b/apps/cpp_rpc/rpc_env.h @@ -40,7 +40,7 @@ namespace runtime { * \param file The format of file * \return Module The loaded module */ -Module Load(std::string *path, const std::string fmt = ""); +Module Load(std::string *path, const std::string& fmt = ""); /*! * \brief CleanDir Removes the files from the directory @@ -62,11 +62,11 @@ struct RPCEnv { * \param name The file name * \return The full path of file. */ - std::string GetPath(std::string file_name); + std::string GetPath(const std::string& file_name) const; /*! * \brief The RPC Environment cleanup function */ - void CleanUp(); + void CleanUp() const; private: /*! diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc index 1a29421..ea4ab00 100644 --- a/apps/cpp_rpc/rpc_server.cc +++ b/apps/cpp_rpc/rpc_server.cc @@ -22,24 +22,27 @@ * \brief RPC Server implementation. 
*/ #include <tvm/runtime/registry.h> - #if defined(__linux__) || defined(__ANDROID__) #include <sys/select.h> #include <sys/wait.h> #endif -#include <set> -#include <iostream> -#include <future> -#include <thread> #include <chrono> +#include <future> +#include <iostream> +#include <set> #include <string> -#include "rpc_server.h" -#include "rpc_env.h" -#include "rpc_tracker_client.h" +#include "../../src/support/socket.h" #include "../../src/runtime/rpc/rpc_session.h" #include "../../src/runtime/rpc/rpc_socket_impl.h" -#include "../../src/support/socket.h" +#include "rpc_env.h" +#include "rpc_server.h" +#include "rpc_tracker_client.h" +#if defined(_WIN32) +#include "win32_process.h" +#endif + +using namespace std::chrono; namespace tvm { namespace runtime { @@ -49,7 +52,7 @@ namespace runtime { * \param status status value */ #if defined(__linux__) || defined(__ANDROID__) -static pid_t waitPidEintr(int *status) { +static pid_t waitPidEintr(int* status) { pid_t pid = 0; while ((pid = waitpid(-1, status, 0)) == -1) { if (errno == EINTR) { @@ -76,34 +79,32 @@ class RPCServer { public: /*! * \brief Constructor. - */ - RPCServer(const std::string &host, - int port, - int port_end, - const std::string &tracker_addr, - const std::string &key, - const std::string &custom_addr) { - // Init the values - host_ = host; - port_ = port; - port_end_ = port_end; - tracker_addr_ = tracker_addr; - key_ = key; - custom_addr_ = custom_addr; + */ + RPCServer(std::string host, int port, int port_end, std::string tracker_addr, + std::string key, std::string custom_addr) : + host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end), + tracker_addr_(std::move(tracker_addr)), key_(std::move(key)), + custom_addr_(std::move(custom_addr)) + { + } /*! * \brief Destructor. - */ + */ ~RPCServer() { - // Free the resources - tracker_sock_.Close(); - listen_sock_.Close(); + try { + // Free the resources + tracker_sock_.Close(); + listen_sock_.Close(); + } catch(...) { + + } } /*! 
* \brief Start Creates the RPC listen process and execution. - */ + */ void Start() { listen_sock_.Create(); my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_); @@ -130,102 +131,98 @@ class RPCServer { tracker.TryConnect(); // step 2: wait for in-coming connections AcceptConnection(&tracker, &conn, &addr, &opts); - } - catch (const char* msg) { + } catch (const char* msg) { LOG(WARNING) << "Socket exception: " << msg; // close tracker resource tracker.Close(); continue; - } - catch (std::exception& e) { - // Other errors + } catch (const std::exception& e) { + // close tracker resource + tracker.Close(); LOG(WARNING) << "Exception standard: " << e.what(); continue; } int timeout = GetTimeOutFromOpts(opts); - #if defined(__linux__) || defined(__ANDROID__) - // step 3: serving - if (timeout != 0) { - const pid_t timer_pid = fork(); - if (timer_pid == 0) { - // Timer process - sleep(timeout); - exit(0); - } +#if defined(__linux__) || defined(__ANDROID__) + // step 3: serving + if (timeout != 0) { + const pid_t timer_pid = fork(); + if (timer_pid == 0) { + // Timer process + sleep(timeout); + exit(0); + } - const pid_t worker_pid = fork(); - if (worker_pid == 0) { - // Worker process - ServerLoopProc(conn, addr); - exit(0); - } + const pid_t worker_pid = fork(); + if (worker_pid == 0) { + // Worker process + ServerLoopProc(conn, addr); + exit(0); + } - int status = 0; - const pid_t finished_first = waitPidEintr(&status); - if (finished_first == timer_pid) { - kill(worker_pid, SIGKILL); - } else if (finished_first == worker_pid) { - kill(timer_pid, SIGKILL); - } else { - LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; - } + int status = 0; + const pid_t finished_first = waitPidEintr(&status); + if (finished_first == timer_pid) { + kill(worker_pid, SIGKILL); + } else if (finished_first == worker_pid) { + kill(timer_pid, SIGKILL); + } else { + LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue."; + 
} - int status_second = 0; - waitPidEintr(&status_second); + int status_second = 0; + waitPidEintr(&status_second); - // Logging. - if (finished_first == timer_pid) { - LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout - << "), Process status = " << status_second; - } else if (finished_first == worker_pid) { - LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; - } - } else { - auto pid = fork(); - if (pid == 0) { - ServerLoopProc(conn, addr); - exit(0); - } - // Wait for the result - int status = 0; - wait(&status); - LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; + // Logging. + if (finished_first == timer_pid) { + LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout + << "), Process status = " << status_second; + } else if (finished_first == worker_pid) { + LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second; } - #else - // step 3: serving - std::future<void> proc(std::async(std::launch::async, - &RPCServer::ServerLoopProc, this, conn, addr)); - // wait until server process finish or timeout - if (timeout != 0) { - // Autoterminate after timeout - proc.wait_for(std::chrono::seconds(timeout)); - } else { - // Wait for the result - proc.get(); + } else { + auto pid = fork(); + if (pid == 0) { + ServerLoopProc(conn, addr); + exit(0); } - #endif + // Wait for the result + int status = 0; + wait(&status); + LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status; + } +#elif defined(WIN32) + auto start_time = high_resolution_clock::now(); + try { + SpawnRPCChild(conn.sockfd, seconds(timeout)); + } catch (const std::exception&) { + + } + auto dur = high_resolution_clock::now() - start_time; + + LOG(INFO) << "Serve Time " << duration_cast<milliseconds>(dur).count() << "ms"; +#endif // close from our side. LOG(INFO) << "Socket Connection Closed"; conn.Close(); } } - /*! 
* \brief AcceptConnection Accepts the RPC Server connection. * \param tracker Tracker details. - * \param conn New connection information. + * \param conn_sock New connection information. * \param addr New connection address information. * \param opts Parsed options for socket * \param ping_period Timeout for select call waiting */ - void AcceptConnection(TrackerClient* tracker, + void AcceptConnection(TrackerClient* tracker, support::TCPSocket* conn_sock, - support::SockAddr* addr, - std::string* opts, + support::SockAddr* addr, + std::string* opts, int ping_period = 2) { - std::set <std::string> old_keyset; + std::set<std::string> old_keyset; std::string matchkey; // Report resource to tracker and get key @@ -236,7 +233,7 @@ class RPCServer { support::TCPSocket conn = listen_sock_.Accept(addr); int code = kRPCMagic; - CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); + CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code)); if (code != kRPCMagic) { conn.Close(); LOG(FATAL) << "Client connected is not TVM RPC server"; @@ -265,15 +262,15 @@ class RPCServer { std::string arg0; ssin >> arg0; if (arg0 != expect_header) { - code = kRPCMismatch; - CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); - conn.Close(); - LOG(WARNING) << "Mismatch key from" << addr->AsString(); - continue; + code = kRPCMismatch; + CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); + conn.Close(); + LOG(WARNING) << "Mismatch key from" << addr->AsString(); + continue; } else { code = kRPCSuccess; CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code)); - keylen = server_key.length(); + keylen = int(server_key.length()); CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen)); CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen); LOG(INFO) << "Connection success " << addr->AsString(); @@ -289,25 +286,23 @@ class RPCServer { * \param sock The socket information * \param addr The socket address information */ - void ServerLoopProc(support::TCPSocket sock, 
support::SockAddr addr) { - // Server loop - auto env = RPCEnv(); - RPCServerLoop(sock.sockfd); - LOG(INFO) << "Finish serving " << addr.AsString(); - env.CleanUp(); + static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) { + // Server loop + const auto env = RPCEnv(); + RPCServerLoop(int(sock.sockfd)); + LOG(INFO) << "Finish serving " << addr.AsString(); + env.CleanUp(); } /*! * \brief GetTimeOutFromOpts Parse and get the timeout option. * \param opts The option string - * \param timeout value after parsing. */ - int GetTimeOutFromOpts(std::string opts) { - std::string cmd; - std::string option = "-timeout="; + int GetTimeOutFromOpts(const std::string& opts) const { + const std::string option = "-timeout="; if (opts.find(option) == 0) { - cmd = opts.substr(opts.find_last_of(option) + 1); + const std::string cmd = opts.substr(opts.find_last_of(option) + 1); CHECK(support::IsNumber(cmd)) << "Timeout is not valid"; return std::stoi(cmd); } @@ -325,29 +320,40 @@ class RPCServer { support::TCPSocket tracker_sock_; }; +#if defined(WIN32) +/*! +* \brief ServerLoopFromChild The Server loop process. +* \param socket The socket information +*/ +void ServerLoopFromChild(SOCKET socket) { + // Server loop + tvm::support::TCPSocket sock(socket); + const auto env = RPCEnv(); + RPCServerLoop(int(sock.sockfd)); + + sock.Close(); + env.CleanUp(); +} +#endif + /*! * \brief RPCServerCreate Creates the RPC Server. * \param host The hostname of the server, Default=0.0.0.0 * \param port The port of the RPC, Default=9090 * \param port_end The end search port of the RPC, Default=9199 - * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" + * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default="" * \param key The key used to identify the device type in tracker. Default="" * \param custom_addr Custom IP Address to Report to RPC Tracker. 
Default="" * \param silent Whether run in silent mode. Default=True */ -void RPCServerCreate(std::string host, - int port, - int port_end, - std::string tracker_addr, - std::string key, - std::string custom_addr, - bool silent) { +void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr, + std::string key, std::string custom_addr, bool silent) { if (silent) { // Only errors and fatal is logged dmlc::InitLogging("--minloglevel=2"); } // Start the rpc server - RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr); + RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr)); rpc.Start(); } diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h index 205182e..db7c89d 100644 --- a/apps/cpp_rpc/rpc_server.h +++ b/apps/cpp_rpc/rpc_server.h @@ -30,6 +30,15 @@ namespace tvm { namespace runtime { +#if defined(WIN32) +/*! + * \brief ServerLoopFromChild The Server loop process. + * \param sock The socket information + * \param addr The socket address information + */ +void ServerLoopFromChild(SOCKET socket); +#endif + /*! * \brief RPCServerCreate Creates the RPC Server. * \param host The hostname of the server, Default=0.0.0.0 @@ -40,13 +49,13 @@ namespace runtime { * \param custom_addr Custom IP Address to Report to RPC Tracker. Default="" * \param silent Whether run in silent mode. 
Default=True */ -TVM_DLL void RPCServerCreate(std::string host = "", - int port = 9090, - int port_end = 9099, - std::string tracker_addr = "", - std::string key = "", - std::string custom_addr = "", - bool silent = true); +void RPCServerCreate(std::string host = "", + int port = 9090, + int port_end = 9099, + std::string tracker_addr = "", + std::string key = "", + std::string custom_addr = "", + bool silent = true); } // namespace runtime } // namespace tvm #endif // TVM_APPS_CPP_RPC_SERVER_H_ diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc new file mode 100644 index 0000000..c6c72d7 --- /dev/null +++ b/apps/cpp_rpc/win32_process.cc @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#include <winsock2.h> +#include <ws2tcpip.h> +#include <cstdio> +#include <memory> +#include <conio.h> +#include <string> +#include <stdexcept> +#include <dmlc/logging.h> +#include "win32_process.h" +#include "rpc_server.h" + +using namespace std::chrono; +using namespace tvm::runtime; + +namespace { +// The prefix path for the memory mapped file used to store IPC information +const std::string kMemoryMapPrefix = "/MAPPED_FILE/TVM_RPC"; +// Used to construct unique names for named resources in the parent process +const std::string kParent = "parent"; +// Used to construct unique names for named resources in the child process +const std::string kChild = "child"; +// The timeout of the WIN32 events, in the parent and the child +const milliseconds kEventTimeout(2000); + +// Used to create unique WIN32 mmap paths and event names +int child_counter_ = 0; + +/*! + * \brief HandleDeleter Deleter for UniqueHandle smart pointer + * \param handle The WIN32 HANDLE to manage + */ +struct HandleDeleter { + void operator()(HANDLE handle) const { + if (handle != INVALID_HANDLE_VALUE && handle != nullptr) { + CloseHandle(handle); + } + } +}; + +/*! + * \brief UniqueHandle Smart pointer to manage a WIN32 HANDLE + */ +using UniqueHandle = std::unique_ptr<void, HandleDeleter>; + +/*! + * \brief MakeUniqueHandle Helper method to construct a UniqueHandle + * \param handle The WIN32 HANDLE to manage + */ +UniqueHandle MakeUniqueHandle(HANDLE handle) { + if (handle == INVALID_HANDLE_VALUE || handle == nullptr) { + return nullptr; + } + + return UniqueHandle(handle); +} + +/*! 
+ * \brief GetSocket Gets the socket info from the parent process and duplicates the socket + * \param mmap_path The path to the memory mapped info set by the parent + */ +SOCKET GetSocket(const std::string& mmap_path) { + WSAPROTOCOL_INFO protocol_info; + + const std::string parent_event_name = mmap_path + kParent; + const std::string child_event_name = mmap_path + kChild; + + // Open the events + UniqueHandle parent_file_mapping_event; + if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); + } + + UniqueHandle child_file_mapping_event; + if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "OpenEvent() failed: " << GetLastError(); + } + + // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read + if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError(); + } + + const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE, + false, + mmap_path.c_str())); + if (!file_map) { + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + } + + void* map_view = MapViewOfFile(file_map.get(), + FILE_MAP_READ | FILE_MAP_WRITE, + 0, 0, 0); + + SOCKET sock_duplicated = INVALID_SOCKET; + + if (map_view != nullptr) { + memcpy(&protocol_info, map_view, sizeof(WSAPROTOCOL_INFO)); + UnmapViewOfFile(map_view); + + // Creates the duplicate socket, that was created in the parent + sock_duplicated = WSASocket(FROM_PROTOCOL_INFO, + FROM_PROTOCOL_INFO, + FROM_PROTOCOL_INFO, + &protocol_info, + 0, + 0); + + // Let the parent know we are finished dupicating the socket + SetEvent(child_file_mapping_event.get()); + } else { + LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError(); 
+ } + + return sock_duplicated; +} +}// Anonymous namespace + +namespace tvm { +namespace runtime { +/*! + * \brief SpawnRPCChild Spawns a child process with a given timeout to run + * \param fd The client socket to duplicate in the child + * \param timeout The time in seconds to wait for the child to complete before termination + */ +void SpawnRPCChild(SOCKET fd, seconds timeout) { + STARTUPINFOA startup_info; + + memset(&startup_info, 0, sizeof(startup_info)); + startup_info.cb = sizeof(startup_info); + + std::string file_map_path = kMemoryMapPrefix + std::to_string(child_counter_++); + + const std::string parent_event_name = file_map_path + kParent; + const std::string child_event_name = file_map_path + kChild; + + // Create an event to let the child know the socket info was set to the mmap file + UniqueHandle parent_file_mapping_event; + if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "CreateEvent for parent file mapping failed"; + } + + UniqueHandle child_file_mapping_event; + // An event to let the parent know the socket info was read from the mmap file + if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) { + LOG(FATAL) << "CreateEvent for child file mapping failed"; + } + + char current_executable[MAX_PATH]; + + // Get the full path of the current executable + GetModuleFileNameA(nullptr, current_executable, MAX_PATH); + + std::string child_command_line = current_executable; + child_command_line += " server --child_proc="; + child_command_line += file_map_path; + + // CreateProcessA requires a non const char*, so we copy our std::string + std::unique_ptr<char[]> command_line_ptr(new char[child_command_line.size() + 1]); + strcpy(command_line_ptr.get(), child_command_line.c_str()); + + PROCESS_INFORMATION child_process_info; + if (CreateProcessA(nullptr, + command_line_ptr.get(), + nullptr, + 
nullptr, + false, + CREATE_NO_WINDOW, + nullptr, + nullptr, + &startup_info, + &child_process_info)) { + // Child process and thread handles must be closed, so wrapped in RAII + auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess); + auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread); + + WSAPROTOCOL_INFO protocol_info; + // Get info needed to duplicate the socket + if (WSADuplicateSocket(fd, + child_process_info.dwProcessId, + &protocol_info) == SOCKET_ERROR) { + LOG(FATAL) << "WSADuplicateSocket(): failed. Error =" << WSAGetLastError(); + } + + // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc + UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE, + nullptr, + PAGE_READWRITE, + 0, + sizeof(WSAPROTOCOL_INFO), + file_map_path.c_str())); + if (!file_map) { + LOG(INFO) << "CreateFileMapping() failed: " << GetLastError(); + } + + if (GetLastError() == ERROR_ALREADY_EXISTS) { + LOG(FATAL) << "CreateFileMapping(): mapping file already exists"; + } else { + void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0); + + if (map_view != nullptr) { + memcpy(map_view, &protocol_info, sizeof(WSAPROTOCOL_INFO)); + UnmapViewOfFile(map_view); + + // Let child proc know the mmap file is ready to be read + SetEvent(parent_file_mapping_event.get()); + + // Wait for the child to finish reading mmap file + if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) { + TerminateProcess(child_process_handle.get(), 0); + LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process."; + } + } else { + TerminateProcess(child_process_handle.get(), 0); + LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError(); + } + } + + const DWORD process_timeout = timeout.count() + ? 
uint32_t(duration_cast<milliseconds>(timeout).count()) + : INFINITE; + + // Wait for child process to exit, or hit configured timeout + if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) { + LOG(INFO) << "Child process timeout. Terminating."; + TerminateProcess(child_process_handle.get(), 0); + } + } else { + LOG(INFO) << "Create child process failed: " << GetLastError(); + } +} +/*! + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket + * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + */ +void ChildProcSocketHandler(const std::string& mmap_path) { + SOCKET socket; + + // Set high thread priority to avoid the thread scheduler from + // interfering with any measurements in the RPC server. + SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL); + + if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) { + tvm::runtime::ServerLoopFromChild(socket); + } + else { + LOG(FATAL) << "GetSocket() failed"; + } + +} +} // namespace runtime +} // namespace tvm \ No newline at end of file diff --git a/apps/cpp_rpc/win32_process.h b/apps/cpp_rpc/win32_process.h new file mode 100644 index 0000000..7d1a276 --- /dev/null +++ b/apps/cpp_rpc/win32_process.h @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + + /*! + * \file win32_process.h + * \brief Win32 process code to mimic a POSIX fork() + */ +#ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ +#define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ +#include <chrono> +#include <string> +namespace tvm { +namespace runtime { +/*! + * \brief SpawnRPCChild Spawns a child process with a given timeout to run + * \param fd The client socket to duplicate in the child + * \param timeout The time in seconds to wait for the child to complete before termination + */ +void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout); +/*! + * \brief ChildProcSocketHandler Ran from the child process and runs server to handle the client socket + * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent + */ +void ChildProcSocketHandler(const std::string& mmap_path); +} // namespace runtime +} // namespace tvm +#endif // TVM_APPS_CPP_RPC_WIN32_PROCESS_H_ \ No newline at end of file diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 6b4e341..642fbb8 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -34,8 +34,12 @@ class SockChannel final : public RPCChannel { explicit SockChannel(support::TCPSocket sock) : sock_(sock) {} ~SockChannel() { - if (!sock_.BadSocket()) { - sock_.Close(); + try { + // BadSocket can throw + if (!sock_.BadSocket()) { + sock_.Close(); + } + } catch (...) 
{ } } size_t Send(const void* data, size_t size) final { @@ -100,7 +104,8 @@ Module RPCClientConnect(std::string url, int port, std::string key) { return CreateRPCModule(RPCConnect(url, port, "client:" + key)); } -void RPCServerLoop(int sockfd) { +// TVM_DLL needed for MSVC +TVM_DLL void RPCServerLoop(int sockfd) { support::TCPSocket sock( static_cast<support::TCPSocket::SockType>(sockfd)); RPCSession::Create( diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h index 7700a96..e6e3b04 100644 --- a/src/support/ring_buffer.h +++ b/src/support/ring_buffer.h @@ -63,7 +63,7 @@ class RingBuffer { size_t ncopy = head_ptr_ + bytes_available_ - old_size; memcpy(&ring_[0] + old_size, &ring_[0], ncopy); } - } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) { + } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) { // shrink too large temporary buffer to avoid out of memory on some embedded devices size_t old_bytes = bytes_available_;