This is an automated email from the ASF dual-hosted git repository.

lanking pushed a commit to branch v1.5.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.5.x by this push:
     new 3b2bd99  fix lstm layer with projection save params (#17266) (#17286)
3b2bd99 is described below

commit 3b2bd999ea5293dfacd5008e459ea3c7271b9faa
Author: Frank Liu <frankfliu2...@gmail.com>
AuthorDate: Mon Feb 3 13:01:50 2020 -0800

    fix lstm layer with projection save params (#17266) (#17286)
    
    Co-authored-by: Sheng Zha <s...@users.noreply.github.com>
---
 python/mxnet/gluon/rnn/rnn_layer.py | 2 +-
 tests/python/gpu/test_gluon_gpu.py  | 2 ++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/gluon/rnn/rnn_layer.py b/python/mxnet/gluon/rnn/rnn_layer.py
index b3cc596..11d4581 100644
--- a/python/mxnet/gluon/rnn/rnn_layer.py
+++ b/python/mxnet/gluon/rnn/rnn_layer.py
@@ -124,7 +124,7 @@ class _RNNLayer(HybridBlock):
     def _collect_params_with_prefix(self, prefix=''):
         if prefix:
             prefix += '.'
-        pattern = re.compile(r'(l|r)(\d)_(i2h|h2h)_(weight|bias)\Z')
+        pattern = re.compile(r'(l|r)(\d)_(i2h|h2h|h2r)_(weight|bias)\Z')
         def convert_key(m, bidirectional): # for compatibility with old parameter format
             d, l, g, t = [m.group(i) for i in range(1, 5)]
             if bidirectional:
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index fc65029..d6070d6 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -137,6 +137,8 @@ def test_lstmp():
     check_rnn_layer_forward(gluon.rnn.LSTM(10, 2, bidirectional=True, dropout=0.5, projection_size=5),
                             mx.nd.ones((8, 3, 20)),
                             [mx.nd.ones((4, 3, 5)), mx.nd.ones((4, 3, 10))], run_only=True, ctx=ctx)
+    lstm_layer.save_parameters('gpu_tmp.params')
+    lstm_layer.load_parameters('gpu_tmp.params')
 
 
 @with_seed()

Reply via email to