This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 31144c7  Fix (#17674)
31144c7 is described below

commit 31144c763bfd0fe199b7fe0f23a20555c9731e7a
Author: reminisce <wujun....@gmail.com>
AuthorDate: Mon Feb 24 19:58:25 2020 -0800

    Fix (#17674)
---
 src/nnvm/plan_memory.cc                   | 25 +++++++++++++------------
 tests/python/unittest/test_numpy_gluon.py | 21 +++++++++++++++++++++
 2 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc
index 6c6e02d..3815f23 100644
--- a/src/nnvm/plan_memory.cc
+++ b/src/nnvm/plan_memory.cc
@@ -38,21 +38,22 @@ namespace {
 // Return bytes of data flag.
 static int MXGetDTypeSize(int type_flag) {
   switch (type_flag) {
-    case kUint8:
-    case kInt8:
+    case mshadow::kUint8:
+    case mshadow::kInt8:
+    case mshadow::kBool:
       return 1;
-    case kFloat16:
-    case kBfloat16:
-    case kInt16:
-    case kUint16:
+    case mshadow::kFloat16:
+    case mshadow::kBfloat16:
+    case mshadow::kInt16:
+    case mshadow::kUint16:
       return 2;
-    case kFloat32:
-    case kInt32:
-    case kUint32:
+    case mshadow::kFloat32:
+    case mshadow::kInt32:
+    case mshadow::kUint32:
       return 4;
-    case kFloat64:
-    case kInt64:
-    case kUint64:
+    case mshadow::kFloat64:
+    case mshadow::kInt64:
+    case mshadow::kUint64:
       return 8;
     default:
       LOG(FATAL) << "unknown type_flag=" << type_flag;
diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py
index 6ce9e18..0d1e5fe 100644
--- a/tests/python/unittest/test_numpy_gluon.py
+++ b/tests/python/unittest/test_numpy_gluon.py
@@ -400,6 +400,27 @@ def test_net_symbol_save_load():
                                   mx.np.random.normal(0, 1, (10, 5, 8))])
 
 
+@with_seed()
+@use_np
+def test_hybridize_boolean_dtype():  # imperative and hybridized runs of a comparison-output block must agree
+    class Foo(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(Foo, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, valid_length):
+            mask = ((F.np.ones((10,)) / 2) < valid_length)  # `<` comparison; output presumably boolean dtype (the case this test targets)
+            return mask
+
+    valid_length = mx.np.random.uniform(size=(10,))
+    foo = Foo()
+    out1 = foo(valid_length)  # reference result from the non-hybridized (imperative) block
+
+    foo = Foo()  # new instance, hybridized before its first call
+    foo.hybridize()
+    out2 = foo(valid_length)  # same input through the hybridized block
+
+    assert mx.test_utils.same(out1.asnumpy(), out2.asnumpy())  # both execution modes must give matching results
+
 
 if __name__ == '__main__':
     import nose

Reply via email to