This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 07aef13  Add unpooled gpu memory type (#14716)
07aef13 is described below

commit 07aef13e879cfdb7f361a245ef6c93f9b3926846
Author: vlado <vladoovtcha...@gmail.com>
AuthorDate: Mon Apr 29 20:24:52 2019 -0600

    Add unpooled gpu memory type (#14716)
    
    * Add unpooled gpu memory type
    
    * Adding missing header
    
    * undo bad rebase change
---
 docs/faq/env_var.md    |  6 +++++-
 src/storage/storage.cc | 10 ++++++----
 2 files changed, 11 insertions(+), 5 deletions(-)
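
Note (not part of the commit): the strategy selection that this change introduces in src/storage/storage.cc can be sketched in isolation as below. This is a minimal illustration, assuming a hypothetical enum `GpuPoolKind` and a hypothetical helper `ParseGpuPoolKind`; neither name exists in MXNet, which constructs the storage manager objects directly. The sketch throws on an unknown name where the real code calls LOG(FATAL).

```cpp
#include <stdexcept>
#include <string>

// Hypothetical enum, for illustration only. The real code instantiates
// GPUPooledRoundedStorageManager, GPUPooledStorageManager, or
// NaiveStorageManager<GPUDeviceStorage> instead of returning a tag.
enum class GpuPoolKind { kNaive, kRound, kUnpooled };

// Maps the value of MXNET_GPU_MEM_POOL_TYPE to a pool kind, following the
// same if / else-if chain as the patched StorageImpl::Alloc.
// Assumption: an unknown name is reported by throwing rather than LOG(FATAL).
inline GpuPoolKind ParseGpuPoolKind(const std::string& strategy) {
  if (strategy == "Round") return GpuPoolKind::kRound;
  if (strategy == "Naive") return GpuPoolKind::kNaive;
  if (strategy == "Unpooled") return GpuPoolKind::kUnpooled;
  throw std::invalid_argument(
      "Unknown memory pool strategy specified: " + strategy + ".");
}
```

As in the diff, the comparison is an exact, case-sensitive string match, so a value such as "unpooled" falls through to the error branch.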

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index 095c214..c5ebd54 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -80,16 +80,20 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
 * MXNET_GPU_MEM_POOL_RESERVE
   - Values: Int ```(default=5)```
   - The percentage of GPU memory to reserve for things other than the GPU array, such as kernel launch or cudnn handle space.
-  - If you see a strange out-of-memory error from the kernel launch, after multiple iterations, try setting this to a larger value.  
+  - If you see a strange out-of-memory error from the kernel launch, after multiple iterations, try setting this to a larger value.
+
 * MXNET_GPU_MEM_POOL_TYPE
   - Values: String ```(default=Naive)```
   - The type of memory pool.
   - Choices:
     - Naive: A simple memory pool that allocates memory for the exact requested size and cache memory buffers. If a buffered memory chunk matches the size of a new request, the chunk from the memory pool will be returned and reused.
     - Round: A memory pool that always rounds the requested memory size and allocates memory of the rounded size. MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF defines how to round up a memory size. Caching and allocating buffered memory works in the same way as the naive memory pool.
+    - Unpooled: No memory pool is used.
+
 * MXNET_GPU_MEM_POOL_ROUND_LINEAR_CUTOFF
   - Values: Int ```(default=24)```
   - The cutoff threshold that decides the rounding strategy. Let's denote the threshold as T. If the memory size is smaller than `2 ** T` (by default, it's 2 ** 24 = 16MB), it rounds to the smallest `2 ** n` that is larger than the requested memory size; if the memory size is larger than `2 ** T`, it rounds to the next k * 2 ** T.
+
 * MXNET_GPU_MEM_LARGE_ALLOC_ROUND_SIZE
   - Values: Int ```(default=2097152)```
   - When using the naive pool type, memory allocations larger than this threshhold are rounded up to a multiple of this value.
diff --git a/src/storage/storage.cc b/src/storage/storage.cc
index 4f15351..0ca5ef7 100644
--- a/src/storage/storage.cc
+++ b/src/storage/storage.cc
@@ -26,6 +26,7 @@
 #include "./pooled_storage_manager.h"
 #include "./cpu_shared_storage_manager.h"
 #include "./cpu_device_storage.h"
+#include "./gpu_device_storage.h"
 #include "./pinned_memory_storage.h"
 #include "../common/lazy_alloc_array.h"
 #include "../profiler/storage_profiler.h"
@@ -106,11 +107,12 @@ void StorageImpl::Alloc(Storage::Handle* handle) {
             if (strategy == "Round") {
               ptr = new storage::GPUPooledRoundedStorageManager(handle->ctx);
               LOG(INFO) << "Using GPUPooledRoundedStorageManager.";
-            } else {
-              if (strategy != "Naive") {
-                LOG(FATAL) << "Unknown memory pool strategy specified: " << strategy << ".";
-              }
+            } else if (strategy == "Naive") {
               ptr = new storage::GPUPooledStorageManager(handle->ctx);
+            } else if (strategy == "Unpooled") {
+              ptr = new storage::NaiveStorageManager<storage::GPUDeviceStorage>();
+            } else {
+              LOG(FATAL) << "Unknown memory pool strategy specified: " << strategy << ".";
             }
 #else
             LOG(FATAL) << "Compile with USE_CUDA=1 to enable GPU usage";
