This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new df25378  add shared storage in windows (#8967)
df25378 is described below

commit df25378ba9b3a5703c515cc941e15f474f3b1d11
Author: Hu Shiwen <yajiedes...@gmail.com>
AuthorDate: Tue Dec 19 03:06:58 2017 +0800

    add shared storage in windows (#8967)
    
    * add shared storage in windows
    
    * fix
    
    * lint
    
    * fix
    
    * fix
    
    * fix
    
    * fix process.h
---
 amalgamation/amalgamation.py             |  1 +
 python/mxnet/gluon/data/dataloader.py    | 10 ++---
 src/storage/cpu_shared_storage_manager.h | 75 +++++++++++++++++++++++++++++---
 3 files changed, 73 insertions(+), 13 deletions(-)

diff --git a/amalgamation/amalgamation.py b/amalgamation/amalgamation.py
index b378817..9419898 100644
--- a/amalgamation/amalgamation.py
+++ b/amalgamation/amalgamation.py
@@ -43,6 +43,7 @@ if platform.system() != 'Darwin':
 
 if platform.system() != 'Windows':
   blacklist.append('windows.h')
+  blacklist.append('process.h')
 
 def pprint(lst):
     for item in lst:
diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index beb228e..8dea59f 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -25,9 +25,7 @@ import multiprocessing.queues
 from multiprocessing.reduction import ForkingPickler
 import pickle
 import io
-import os
 import sys
-import warnings
 import numpy as np
 
 from . import sampler as _sampler
@@ -52,7 +50,7 @@ class ConnectionWrapper(object):
     NDArray via shared memory."""
 
     def __init__(self, conn):
-        self.conn = conn
+        self._conn = conn  # stored as _conn; __getattr__ reads it back through self.__dict__
 
     def send(self, obj):
         """Send object"""
@@ -67,7 +65,8 @@ class ConnectionWrapper(object):
 
     def __getattr__(self, name):
         """Emmulate conn"""
-        return getattr(self.conn, name)
+        attr = self.__dict__.get('_conn', None)  # via __dict__: self._conn would re-enter __getattr__ if unset
+        return getattr(attr, name)  # NOTE(review): if _conn is unset, `name` is looked up on None
 
 
 class Queue(multiprocessing.queues.Queue):
@@ -188,9 +187,6 @@ class DataLoader(object):
                              "not be specified if batch_sampler is specified.")
 
         self._batch_sampler = batch_sampler
-        if num_workers > 0 and os.name == 'nt':
-            warnings.warn("DataLoader does not support num_workers > 0 on 
Windows yet.")
-            num_workers = 0
         self._num_workers = num_workers
         if batchify_fn is None:
             if num_workers > 0:
diff --git a/src/storage/cpu_shared_storage_manager.h b/src/storage/cpu_shared_storage_manager.h
index 9f0f2a3..98f706b 100644
--- a/src/storage/cpu_shared_storage_manager.h
+++ b/src/storage/cpu_shared_storage_manager.h
@@ -31,6 +31,9 @@
 #include <unistd.h>
 #include <sys/types.h>
 #include <sys/stat.h>
+#else
+#include <Windows.h>
+#include <process.h>
 #endif  // _WIN32
 
 #include <unordered_map>
@@ -64,6 +67,9 @@ class CPUSharedStorageManager final : public StorageManager {
     for (const auto& kv : pool_) {
       FreeImpl(kv.second);
     }
+#ifdef _WIN32
+    CheckAndRealFree();
+#endif
   }
 
   void Alloc(Storage::Handle* handle) override;
@@ -91,11 +97,18 @@ class CPUSharedStorageManager final : public StorageManager {
  private:
   static constexpr size_t alignment_ = 16;
 
-  std::mutex mutex_;
+  std::recursive_mutex mutex_;
   std::mt19937 rand_gen_;
   std::unordered_map<void*, Storage::Handle> pool_;
+#ifdef _WIN32
+  std::unordered_map<void*, Storage::Handle> is_free_;
+  std::unordered_map<void*, HANDLE> map_handle_map_;
+#endif
 
   void FreeImpl(const Storage::Handle& handle);
+#ifdef _WIN32
+  void CheckAndRealFree();
+#endif
 
   std::string SharedHandleToString(int shared_pid, int shared_id) {
     std::stringstream name;
@@ -106,14 +119,44 @@ class CPUSharedStorageManager final : public StorageManager {
 };  // class CPUSharedStorageManager
 
 void CPUSharedStorageManager::Alloc(Storage::Handle* handle) {
-  std::lock_guard<std::mutex> lock(mutex_);
+  std::lock_guard<std::recursive_mutex> lock(mutex_);
   std::uniform_int_distribution<> dis(0, std::numeric_limits<int>::max());
   int fid = -1;
   bool is_new = false;
   size_t size = handle->size + alignment_;
-  void* ptr = nullptr;
-#ifdef _WIN32
-  LOG(FATAL) << "Shared memory is not supported on Windows yet.";
+  void *ptr = nullptr;
+  #ifdef _WIN32
+  CheckAndRealFree();
+  HANDLE map_handle = nullptr;
+  uint32_t error = 0;
+  if (handle->shared_id == -1 && handle->shared_pid == -1) {
+    is_new = true;
+    handle->shared_pid = _getpid();
+    for (int i = 0; i < 10; ++i) {
+      handle->shared_id = dis(rand_gen_);
+      auto filename = SharedHandleToString(handle->shared_pid, 
handle->shared_id);
+      map_handle = CreateFileMapping(INVALID_HANDLE_VALUE,
+                                     NULL, PAGE_READWRITE, 0, size, 
filename.c_str());
+      if ((error = GetLastError()) == ERROR_SUCCESS) {
+        break;;
+      }
+    }
+  } else {
+    auto filename = SharedHandleToString(handle->shared_pid, 
handle->shared_id);
+    map_handle = OpenFileMapping(FILE_MAP_READ | FILE_MAP_WRITE,
+                                 FALSE, filename.c_str());
+    error = GetLastError();
+  }
+
+  if (error != ERROR_SUCCESS && map_handle == nullptr) {
+    LOG(FATAL) << "Failed to open shared memory. CreateFileMapping failed with 
error "
+               << error;
+  }
+
+  ptr = MapViewOfFile(map_handle, FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0);
+  CHECK_NE(ptr, (void *)0)
+      << "Failed to map shared memory. MapViewOfFile failed with error " << 
GetLastError();
+  map_handle_map_[ptr] = map_handle;
 #else
   if (handle->shared_id == -1 && handle->shared_pid == -1) {
     is_new = true;
@@ -153,7 +196,7 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
   int count = DecrementRefCount(handle);
   CHECK_GE(count, 0);
 #ifdef _WIN32
-  LOG(FATAL) << "Shared memory is not supported on Windows yet.";
+  is_free_[handle.dptr] = handle;
 #else
   CHECK_EQ(munmap(static_cast<char*>(handle.dptr) - alignment_,
                   handle.size + alignment_), 0)
@@ -169,6 +212,26 @@ void CPUSharedStorageManager::FreeImpl(const Storage::Handle& handle) {
 #endif  // _WIN32
 }
 
+#ifdef _WIN32
+inline void CPUSharedStorageManager::CheckAndRealFree() {  // unmap/close deferred frees whose ref count is 0
+  std::lock_guard<std::recursive_mutex> lock(mutex_);  // recursive: Alloc() calls this with mutex_ held
+  for (auto it = std::begin(is_free_); it != std::end(is_free_);) {  // entries queued by FreeImpl()
+    void* ptr = static_cast<char*>(it->second.dptr) - alignment_;  // presumed view base (Alloc keys map_handle_map_ by it)
+    std::atomic<int>* counter = reinterpret_cast<std::atomic<int>*>(  // ref count is read from that same base address
+      static_cast<char*>(it->second.dptr) - alignment_);
+    if ((*counter) == 0) {  // ref count has dropped to zero: release the view and its mapping handle
+      CHECK_NE(UnmapViewOfFile(ptr), 0)
+        << "Failed to UnmapViewOfFile shared memory ";
+      CHECK_NE(CloseHandle(map_handle_map_[ptr]), 0)  // NOTE(review): operator[] would insert nullptr if ptr were missing
+        << "Failed to CloseHandle shared memory ";
+      map_handle_map_.erase(ptr);
+      it = is_free_.erase(it);  // erase() returns the iterator to the next entry
+    } else {
+      ++it;  // still referenced; retried on a later Alloc() or at destruction
+    }
+  }
+}
+#endif  // _WIN32
 }  // namespace storage
 }  // namespace mxnet
 

-- 
To stop receiving notification emails like this one, please contact
"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>.

Reply via email to