This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch v1.7.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.7.x by this push:
     new c4d9270  Fix memory leaks in Gluon (#18328) (#18358)
c4d9270 is described below

commit c4d9270dde5c091386dbdbd53f8060a73b98cbc9
Author: Leonard Lausen <lau...@amazon.com>
AuthorDate: Mon May 18 18:51:45 2020 -0700

    Fix memory leaks in Gluon (#18328) (#18358)
    
    Fix leak of ndarray objects in the frontend due to reference cycle.
    
    Backport of 3e676fc2c88bec75e4463c8fa9b5532664d518c2
---
 python/mxnet/gluon/block.py                | 21 ++++++++++------
 tests/python/unittest/test_gluon.py        | 39 ++++++++++++++++++++++++++++++
 tests/python/unittest/test_thread_local.py |  5 ++--
 3 files changed, 55 insertions(+), 10 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index bed6679..968c787 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -23,8 +23,10 @@ __all__ = ['Block', 'HybridBlock', 'SymbolBlock']
 import threading
 import copy
 import warnings
-import re
+import weakref
 from collections import OrderedDict, defaultdict
+
+import re
 import numpy as np
 
 from ..base import mx_real_t, MXNetError
@@ -46,7 +48,7 @@ class _BlockScope(object):
     _current = threading.local()
 
     def __init__(self, block):
-        self._block = block
+        self._block = weakref.ref(block) if block is not None else None
         self._counter = {}
         self._old_scope = None
         self._name_scope = None
@@ -55,7 +57,8 @@ class _BlockScope(object):
     def create(prefix, params, hint):
         """Creates prefix and params for new `Block`."""
         current = getattr(_BlockScope._current, "value", None)
-        if current is None:
+        block = current._block() if current is not None else None
+        if current is None or block is None:
             if prefix is None:
                 if not hasattr(_name.NameManager._current, "value"):
                     _name.NameManager._current.value = _name.NameManager()
@@ -71,23 +74,25 @@ class _BlockScope(object):
             prefix = '%s%d_'%(hint, count)
             current._counter[hint] = count + 1
         if params is None:
-            parent = current._block.params
+            parent = block.params
             params = ParameterDict(parent.prefix+prefix, parent._shared)
         else:
             params = ParameterDict(params.prefix, params)
-        return current._block.prefix+prefix, params
+        return block.prefix + prefix, params
 
     def __enter__(self):
-        if self._block._empty_prefix:
+        block = self._block()
+        if block is None or block._empty_prefix:
             return self
         self._old_scope = getattr(_BlockScope._current, "value", None)
         _BlockScope._current.value = self
-        self._name_scope = _name.Prefix(self._block.prefix)
+        self._name_scope = _name.Prefix(block.prefix)
         self._name_scope.__enter__()
         return self
 
     def __exit__(self, ptype, value, trace):
-        if self._block._empty_prefix:
+        block = self._block()
+        if block is None or block._empty_prefix:
             return
         self._name_scope.__exit__(ptype, value, trace)
         self._name_scope = None
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index a026825..cf6bc36 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -17,6 +17,7 @@
 
 import os
 import tempfile
+import gc
 
 import mxnet as mx
 from mxnet import gluon
@@ -3212,6 +3213,44 @@ def test_reqs_switching_training_inference():
 
     mx.test_utils.assert_almost_equal(grad1, grad2)
 
+def test_no_memory_leak_in_gluon():
+    # Collect all other garbage prior to this test. Otherwise the test may fail
+    # due to unrelated memory leaks.
+    gc.collect()
+
+    gc_flags = gc.get_debug()
+    gc.set_debug(gc.DEBUG_SAVEALL)
+    net = mx.gluon.nn.Dense(10, in_units=10)
+    net.initialize()
+    del net
+    gc.collect()
+    gc.set_debug(gc_flags)  # reset gc flags
+
+    # Check for leaked NDArrays
+    seen = set()
+    def has_array(element):
+        try:
+            if element in seen:
+                return False
+            seen.add(element)
+        except TypeError:  # unhashable
+            pass
+
+        if isinstance(element, mx.nd._internal.NDArrayBase):
+            return True
+        elif hasattr(element, '__dict__'):
+            return any(has_array(x) for x in vars(element))
+        elif isinstance(element, dict):
+            return any(has_array(x) for x in element.items())
+        else:
+            try:
+                return any(has_array(x) for x in element)
+            except (TypeError, KeyError):
+                return False
+
+    assert not any(has_array(x) for x in gc.garbage), 'Found leaked NDArrays due to reference cycles'
+    del gc.garbage[:]
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
diff --git a/tests/python/unittest/test_thread_local.py b/tests/python/unittest/test_thread_local.py
index f0e3c66..50ecb06 100644
--- a/tests/python/unittest/test_thread_local.py
+++ b/tests/python/unittest/test_thread_local.py
@@ -124,8 +124,9 @@ def test_blockscope():
     status = [False]
     event = threading.Event()
     def f():
-        with block._BlockScope(dummy_block("spawned_")):
-            x= NameManager.current.get(None, "hello")
+        net = dummy_block("spawned_")  # BlockScope only keeps a weakref to the Block
+        with block._BlockScope(net):
+            x = NameManager.current.get(None, "hello")
             event.wait()
             if x == "spawned_hello0":
                 status[0] = True

Reply via email to