This is an automated email from the ASF dual-hosted git repository.

haoj pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 24a14b3d46bf95f1ab1deac1767133893cc61523
Author: Jake Lee <gstu1...@gmail.com>
AuthorDate: Thu Jul 11 20:51:32 2019 -0700

    fix memory override bug in multinomial (#15397)
---
 src/operator/numpy/random/np_multinomial_op.h | 22 +++++++++++-----------
 tests/python/unittest/test_numpy_ndarray.py   | 10 ++++++++++
 2 files changed, 21 insertions(+), 11 deletions(-)

diff --git a/src/operator/numpy/random/np_multinomial_op.h b/src/operator/numpy/random/np_multinomial_op.h
index 39515b4..7115f27 100644
--- a/src/operator/numpy/random/np_multinomial_op.h
+++ b/src/operator/numpy/random/np_multinomial_op.h
@@ -105,7 +105,7 @@ struct multinomial_kernel {
                                   const int num_exp,
                                   const int prob_length,
                                   DType* pvals,
-                                  float* uniform,
+                                  double* uniform,
                                   int64_t* out) {
     for (int j = 0; j < num_exp; ++j) {
       DType loc = static_cast<DType>(uniform[i * num_exp + j]);
@@ -145,20 +145,20 @@ void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
   int num_output = outputs[0].Size() / prob_length;
   int num_exp = param.n;
   Stream<xpu> *s = ctx.get_stream<xpu>();
-  Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
-  Tensor<xpu, 1, float> uniform =
-      ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(num_output * param.n), s);
-  prnd->SampleUniform(&uniform, 0, 1);
+  Random<xpu, double> *prnd = ctx.requested[0].get_random<xpu, double>(s);
+  size_t temp_space_ = (param.pvals.has_value())
+                      ? num_output * param.n + prob_length : num_output * param.n;
+  Tensor<xpu, 1, double> temp_tensor =
+      ctx.requested[1].get_space_typed<xpu, 1, double>(Shape1(temp_space_), s);
 
+  prnd->SampleUniform(&temp_tensor, 0, 1);
   // set zero for the outputs
   Kernel<set_zero, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<int64_t>());
-
   if (param.pvals.has_value()) {
     // create a tensor to copy the param.pvals tuple to avoid
     // error: calling a __host__ function from a __host__ __device__ function is not allowed
-    Tensor<xpu, 1, double> pvals =
-      ctx.requested[1].get_space_typed<xpu, 1, double>(Shape1(prob_length), s);
-    double* pvals_ = pvals.dptr_;
+    // reuse the uniform temp space to create pval tensor
+    double* pvals_ = temp_tensor.dptr_ + num_output * param.n;
     // check if sum of input(pvals) > 1.0
     double sum = 0.0;
     for (int i = 0; i < prob_length; ++i) {
@@ -169,7 +169,7 @@ void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
           << "sum(pvals[:-1]) > 1.0";
     }
     Kernel<multinomial_kernel, xpu>::Launch(
-      s, num_output, num_exp, prob_length, pvals_, uniform.dptr_, outputs[0].dptr<int64_t>());
+      s, num_output, num_exp, prob_length, pvals_, temp_tensor.dptr_, outputs[0].dptr<int64_t>());
   } else {
     MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       // check if sum of input(pvals) > 1.0
@@ -182,7 +182,7 @@ void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
       }
       Kernel<multinomial_kernel, xpu>::Launch(
         s, num_output, num_exp, prob_length,
-        inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int64_t>());
+        inputs[0].dptr<DType>(), temp_tensor.dptr_, outputs[0].dptr<int64_t>());
     });
   }
 }
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index c5a9279..887bb9a 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -712,6 +712,16 @@ def test_np_multinomial():
         for size in sizes:
             freq = mx.np.random.multinomial(experiements, pvals, 
size=size).asnumpy()
             assert freq.size == 0
+    # test small experiment for github issue
+    # https://github.com/apache/incubator-mxnet/issues/15383
+    small_exp, total_exp = 20, 10000
+    for pvals in pvals_list:
+        x = np.random.multinomial(small_exp, pvals)
+        for i in range(total_exp // small_exp):
+            x = x + np.random.multinomial(20, pvals)
+    freq = (x.asnumpy() / _np.float32(total_exp)).reshape((-1, len(pvals)))
+    for i in range(freq.shape[0]):
+        mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1)
 
 
 if __name__ == '__main__':

Reply via email to