This is an automated email from the ASF dual-hosted git repository. apeforest pushed a commit to branch v1.5.x in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.5.x by this push: new f9dbd0e Fix crash in random.shuffle operator (#15041) f9dbd0e is described below commit f9dbd0e05fb25ff6c773e0b587a92766e60bedf4 Author: Lin Yuan <apefor...@gmail.com> AuthorDate: Thu May 23 13:02:15 2019 -0700 Fix crash in random.shuffle operator (#15041) * fix crash in random_shuffle caused by int overflow * add unit test * add comment * remove small random test to avoid CI failure --- src/operator/random/shuffle_op.cc | 9 +++++++-- tests/python/unittest/test_random.py | 2 ++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/operator/random/shuffle_op.cc b/src/operator/random/shuffle_op.cc index 1bd70b1..345a771 100644 --- a/src/operator/random/shuffle_op.cc +++ b/src/operator/random/shuffle_op.cc @@ -45,8 +45,13 @@ namespace { template<typename DType, typename Rand> void Shuffle1D(DType* const out, const index_t size, Rand* const prnd) { #ifdef USE_GNU_PARALLEL_SHUFFLE - auto rand_n = [prnd](index_t n) { - std::uniform_int_distribution<index_t> dist(0, n - 1); + /* + * See issue #15029: the data type of n needs to be compatible with + * the gcc library: https://github.com/gcc-mirror/gcc/blob/master/libstdc%2B%2B\ + * -v3/include/parallel/random_shuffle.h#L384 + */ + auto rand_n = [prnd](uint32_t n) { + std::uniform_int_distribution<uint32_t> dist(0, n - 1); return dist(*prnd); }; __gnu_parallel::random_shuffle(out, out + size, rand_n); diff --git a/tests/python/unittest/test_random.py b/tests/python/unittest/test_random.py index 5e809d3..4d14719 100644 --- a/tests/python/unittest/test_random.py +++ b/tests/python/unittest/test_random.py @@ -867,6 +867,8 @@ def test_shuffle(): # Test larger arrays testLarge(mx.nd.arange(0, 100000).reshape((10, 10000)), 10) testLarge(mx.nd.arange(0, 100000).reshape((10000, 10)), 10) + testLarge(mx.nd.arange(0, 100000), 10) + @with_seed() def test_randint():