This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
     new 31144c7  Fix (#17674)
31144c7 is described below

commit 31144c763bfd0fe199b7fe0f23a20555c9731e7a
Author: reminisce <wujun....@gmail.com>
AuthorDate: Mon Feb 24 19:58:25 2020 -0800

    Fix (#17674)
---
 src/nnvm/plan_memory.cc                   | 25 +++++++++++++------------
 tests/python/unittest/test_numpy_gluon.py | 21 +++++++++++++++++++++
 2 files changed, 34 insertions(+), 12 deletions(-)

diff --git a/src/nnvm/plan_memory.cc b/src/nnvm/plan_memory.cc
index 6c6e02d..3815f23 100644
--- a/src/nnvm/plan_memory.cc
+++ b/src/nnvm/plan_memory.cc
@@ -38,21 +38,22 @@ namespace {
 // Return bytes of data flag.
 static int MXGetDTypeSize(int type_flag) {
   switch (type_flag) {
-    case kUint8:
-    case kInt8:
+    case mshadow::kUint8:
+    case mshadow::kInt8:
+    case mshadow::kBool:
       return 1;
-    case kFloat16:
-    case kBfloat16:
-    case kInt16:
-    case kUint16:
+    case mshadow::kFloat16:
+    case mshadow::kBfloat16:
+    case mshadow::kInt16:
+    case mshadow::kUint16:
       return 2;
-    case kFloat32:
-    case kInt32:
-    case kUint32:
+    case mshadow::kFloat32:
+    case mshadow::kInt32:
+    case mshadow::kUint32:
       return 4;
-    case kFloat64:
-    case kInt64:
-    case kUint64:
+    case mshadow::kFloat64:
+    case mshadow::kInt64:
+    case mshadow::kUint64:
       return 8;
     default:
       LOG(FATAL) << "unknown type_flag=" << type_flag;
diff --git a/tests/python/unittest/test_numpy_gluon.py b/tests/python/unittest/test_numpy_gluon.py
index 6ce9e18..0d1e5fe 100644
--- a/tests/python/unittest/test_numpy_gluon.py
+++ b/tests/python/unittest/test_numpy_gluon.py
@@ -400,6 +400,27 @@ def test_net_symbol_save_load():
                                   mx.np.random.normal(0, 1, (10, 5, 8))])
 
 
+@with_seed()
+@use_np
+def test_hybridize_boolean_dtype():
+    class Foo(gluon.HybridBlock):
+        def __init__(self, prefix=None, params=None):
+            super(Foo, self).__init__(prefix=prefix, params=params)
+
+        def hybrid_forward(self, F, valid_length):
+            mask = ((F.np.ones((10,)) / 2) < valid_length)
+            return mask
+
+    valid_length = mx.np.random.uniform(size=(10,))
+    foo = Foo()
+    out1 = foo(valid_length)
+
+    foo = Foo()
+    foo.hybridize()
+    out2 = foo(valid_length)
+
+    assert mx.test_utils.same(out1.asnumpy(), out2.asnumpy())
+
 if __name__ == '__main__':
     import nose
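For context, the scenario the new test exercises can be reproduced outside the test suite. The following is a minimal standalone sketch (not part of the commit): it assumes an MXNet build with the numpy interface enabled via npx.set_np(), and the block and variable names are illustrative rather than taken from the library. It mirrors what test_hybridize_boolean_dtype checks, namely that a HybridBlock whose output is a boolean mask returns the same result with and without hybridization, which previously failed because the memory planner could not size boolean storage.

    import mxnet as mx
    from mxnet import gluon, npx

    npx.set_np()  # enable numpy semantics for Gluon blocks

    class BooleanMask(gluon.HybridBlock):
        def hybrid_forward(self, F, valid_length):
            # The comparison yields a boolean ndarray; hybridizing this block
            # routes it through the graph memory planner fixed in this commit.
            return (F.np.ones((10,)) / 2) < valid_length

    valid_length = mx.np.random.uniform(size=(10,))

    block = BooleanMask()
    eager_out = block(valid_length)      # imperative execution

    block = BooleanMask()
    block.hybridize()                    # cached-graph execution path
    hybrid_out = block(valid_length)

    assert mx.test_utils.same(eager_out.asnumpy(), hybrid_out.asnumpy())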