This is an automated email from the ASF dual-hosted git repository. marcoabreu pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 22e5ae3 add type switch to weight tensor (#16543) 22e5ae3 is described below commit 22e5ae39d0be39b9f280e89baeaf002c3572bd83 Author: Xi Wang <xid...@gmail.com> AuthorDate: Mon Oct 28 03:37:55 2019 +0800 add type switch to weight tensor (#16543) --- src/operator/numpy/random/np_choice_op.h | 20 +++++++++++++------- tests/python/unittest/test_numpy_op.py | 21 +++++++++++---------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/src/operator/numpy/random/np_choice_op.h b/src/operator/numpy/random/np_choice_op.h index 335cc27..a6a7cec 100644 --- a/src/operator/numpy/random/np_choice_op.h +++ b/src/operator/numpy/random/np_choice_op.h @@ -118,15 +118,17 @@ struct random_indices { // Weighted sample without replacement. // Use perturbed Gumbel variates as keys. +template <typename IType> struct generate_keys { - MSHADOW_XINLINE static void Map(index_t i, float *uniforms, float *weights) { + MSHADOW_XINLINE static void Map(index_t i, float *uniforms, IType *weights) { uniforms[i] = -logf(-logf(uniforms[i])) + logf(weights[i]); } }; // Weighted sample with replacement. +template <typename IType> struct categorical_sampling { - MSHADOW_XINLINE static void Map(index_t i, float *weights, size_t length, + MSHADOW_XINLINE static void Map(index_t i, IType *weights, size_t length, float *uniforms, int64_t *outs) { outs[i] = 0; float acc = 0.0; @@ -179,15 +181,19 @@ void NumpyChoiceForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx, prnd->SampleUniform(&random_numbers, 0, 1); workspace_ptr += ((random_tensor_size * sizeof(float) / 7 + 1) * 8); if (replace) { - Kernel<categorical_sampling, xpu>::Launch( - s, output_size, inputs[weight_index].dptr<float>(), input_size, - random_numbers.dptr_, outputs[0].dptr<int64_t>()); + MSHADOW_REAL_TYPE_SWITCH(inputs[weight_index].type_flag_, IType, { + Kernel<categorical_sampling<IType>, xpu>::Launch( + s, output_size, inputs[weight_index].dptr<IType>(), input_size, + random_numbers.dptr_, outputs[0].dptr<int64_t>()); + }); } else { Tensor<xpu, 1, int64_t> indices = Tensor<xpu, 1, int64_t>( reinterpret_cast<int64_t *>(workspace_ptr), Shape1(indices_size), s); indices = expr::range((int64_t)0, input_size); - Kernel<generate_keys, xpu>::Launch(s, input_size, random_numbers.dptr_, - inputs[weight_index].dptr<float>()); + MSHADOW_REAL_TYPE_SWITCH(inputs[weight_index].type_flag_, IType, { + Kernel<generate_keys<IType>, xpu>::Launch(s, input_size, random_numbers.dptr_, + inputs[weight_index].dptr<IType>()); + }); _sort<xpu>(random_numbers.dptr_, indices.dptr_, input_size); Copy(outputs[0].FlatTo1D<xpu, int64_t>(s), indices.Slice(0, output_size), s); } diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 98a7b05..0177809 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2490,16 +2490,17 @@ def test_np_choice(): # test_sample_without_replacement(np.random.choice, num_classes, shape, 10 ** 5, weight) # Test hypridize mode: - for hybridize in [True, False]: - for replace in [True, False]: - test_choice = TestUniformChoice(num_classes // 2, replace) - test_choice_weighted = TestWeightedChoice(num_classes // 2, replace) - if hybridize: - test_choice.hybridize() - test_choice_weighted.hybridize() - weight = np.array(_np.random.dirichlet([1.0] * num_classes)) - test_indexing_mode(test_choice, num_classes, num_classes // 2, replace, None) - test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) + for wtype in ['float16', 'float32', 'float64']: + for hybridize in [True, False]: + for replace in [True, False]: + test_choice = TestUniformChoice(num_classes // 2, replace) + test_choice_weighted = TestWeightedChoice(num_classes // 2, replace) + if hybridize: + test_choice.hybridize() + test_choice_weighted.hybridize() + weight = np.array(_np.random.dirichlet([1.0] * num_classes)).astype(wtype) + test_indexing_mode(test_choice, num_classes, num_classes // 2, replace, None) + test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) @with_seed()