This is an automated email from the ASF dual-hosted git repository. haoj pushed a commit to branch numpy in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit c81c12904795e6bf96112b3e6452e1504e0c691e Author: Jake Lee <gstu1...@gmail.com> AuthorDate: Thu Jun 20 00:14:36 2019 -0700 Numpy compatible multinomial (#15219) * draft of multinomial * rename to more concise name * finish shape * complete the forward function * complete forward without handle 0 dimension & scalar * handle 0 dimension * add new line * fix lint * fix the build error * fix lint * finish unit test * change the registration * make multinomial support pvals as mx.ndarray * delete newline * fix lint error * support input as list, mx.ndarray, np.ndarray & unit test * fix lint * fix the include error * fix lint * refactor & pass the tensor instead of tuple to kernel * fix lint * updata the doc * address the comment --- python/mxnet/_numpy_op_doc.py | 30 ++++ python/mxnet/ndarray/numpy/random.py | 41 +++++- python/mxnet/numpy/random.py | 30 ++++ src/operator/numpy/random/np_multinomial_op.cc | 61 ++++++++ src/operator/numpy/random/np_multinomial_op.cu | 34 +++++ src/operator/numpy/random/np_multinomial_op.h | 193 +++++++++++++++++++++++++ tests/python/unittest/test_numpy_ndarray.py | 47 +++++- 7 files changed, 434 insertions(+), 2 deletions(-) diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 9265a98..ab81732 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -109,3 +109,33 @@ def _np_repeat(a, repeats, axis=None): the given axis. """ pass + + +def _npi_multinomial(a): + """Draw samples from a multinomial distribution. + + The multinomial distribution is a multivariate generalisation of the binomial distribution. + Take an experiment with one of ``p`` possible outcomes. An example of such an experiment is throwing a dice, + where the outcome can be 1 through 6. Each sample drawn from the distribution represents n such experiments. + Its values, ``X_i = [X_0, X_1, ..., X_p]``, represent the number of times the outcome was ``i``. 
+ + + Parameters + ---------- + n : int + Number of experiments. + pvals : sequence of floats, length p + Probabilities of each of the p different outcomes. These should sum to 1 + (however, the last element is always assumed to account for the remaining + probability, as long as ``sum(pvals[:-1]) <= 1``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` + samples are drawn. Default is None, in which case a single sample is returned. + + Returns + ------- + out : ndarray + The drawn samples, of shape ``size + (p,)`` if ``size`` was provided; otherwise of shape ``(p,)``. + In other words, each entry ``out[i,j,...,:]`` is a p-dimensional value drawn from the distribution. + """ + pass diff --git a/python/mxnet/ndarray/numpy/random.py b/python/mxnet/ndarray/numpy/random.py index 3d9fd6a..8607fd5 100644 --- a/python/mxnet/ndarray/numpy/random.py +++ b/python/mxnet/ndarray/numpy/random.py @@ -17,11 +17,13 @@ """Namespace for operators used in Gluon dispatched by F=ndarray.""" from __future__ import absolute_import +import numpy as np from ...base import numeric_types from ...context import current_context +from ..ndarray import NDArray from . import _internal as _npi -__all__ = ['uniform', 'normal'] +__all__ = ['uniform', 'normal', 'multinomial'] def _random_helper(random, sampler, params, shape, dtype, ctx, out, kwargs): @@ -135,3 +137,40 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): out = kwargs.pop('out', None) return _random_helper(_npi.random_normal, None, [loc, scale], size, dtype, ctx, out, kwargs) + + +def multinomial(n, pvals, size=None): + """Draw samples from a multinomial distribution. + + The multinomial distribution is a multivariate generalisation of the binomial distribution. + Take an experiment with one of ``p`` possible outcomes. An example of such an experiment is throwing a die, + where the outcome can be 1 through 6.
Each sample drawn from the distribution represents n such experiments. + Its values, ``X = [X_0, X_1, ..., X_{p-1}]``, where ``X_i`` is the number of times the outcome was ``i``. + + + Parameters + ---------- + n : int + Number of experiments. + pvals : sequence of floats, length p + Probabilities of each of the p different outcomes. These should sum to 1 + (however, the last element is always assumed to account for the remaining + probability, as long as ``sum(pvals[:-1]) <= 1``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` + samples are drawn. Default is None, in which case a single sample is returned. + + Returns + ------- + out : ndarray + The drawn samples, of shape ``size + (p,)`` if ``size`` was provided; otherwise of shape ``(p,)``. + In other words, each entry ``out[i,j,...,:]`` is a p-dimensional value drawn from the distribution. + """ + if isinstance(pvals, NDArray): + return _npi.multinomial(pvals, pvals=None, n=n, size=size) + else: + if isinstance(pvals, np.ndarray): + pvals = pvals.tolist() + if any(isinstance(i, list) for i in pvals): + raise ValueError('object too deep for desired array') + return _npi.multinomial(n=n, pvals=pvals, size=size) diff --git a/python/mxnet/numpy/random.py b/python/mxnet/numpy/random.py index baeab8b..cda1ada 100644 --- a/python/mxnet/numpy/random.py +++ b/python/mxnet/numpy/random.py @@ -98,3 +98,33 @@ def normal(loc=0.0, scale=1.0, size=None, **kwargs): This function currently does not support ``loc`` and ``scale`` as ndarrays. """ return _mx_nd_np.random.normal(loc, scale, size, **kwargs) + + +def multinomial(n, pvals, size=None, **kwargs): + """Draw samples from a multinomial distribution. + + The multinomial distribution is a multivariate generalisation of the binomial distribution. + Take an experiment with one of ``p`` possible outcomes. An example of such an experiment is throwing a die, + where the outcome can be 1 through 6.
Each sample drawn from the distribution represents n such experiments. + Its values, ``X = [X_0, X_1, ..., X_{p-1}]``, where ``X_i`` is the number of times the outcome was ``i``. + + + Parameters + ---------- + n : int + Number of experiments. + pvals : sequence of floats, length p + Probabilities of each of the p different outcomes. These should sum to 1 + (however, the last element is always assumed to account for the remaining + probability, as long as ``sum(pvals[:-1]) <= 1``). + size : int or tuple of ints, optional + Output shape. If the given shape is, e.g., ``(m, n, k)``, then ``m * n * k`` + samples are drawn. Default is None, in which case a single sample is returned. + + Returns + ------- + out : ndarray + The drawn samples, of shape ``size + (p,)`` if ``size`` was provided; otherwise of shape ``(p,)``. + In other words, each entry ``out[i,j,...,:]`` is a p-dimensional value drawn from the distribution. + """ + return _mx_nd_np.random.multinomial(n, pvals, size, **kwargs) diff --git a/src/operator/numpy/random/np_multinomial_op.cc b/src/operator/numpy/random/np_multinomial_op.cc new file mode 100644 index 0000000..bf4f88c --- /dev/null +++ b/src/operator/numpy/random/np_multinomial_op.cc @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_multinomial_op.cc + * \brief Operator for numpy sampling from multinomial distributions + */ +#include "./np_multinomial_op.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(NumpyMultinomialParam); + +NNVM_REGISTER_OP(_npi_multinomial) +.describe(R"code(Draw samples from a multinomial distribution. " +"The multinomial distribution is a multivariate generalisation of the binomial distribution. " +"Take an experiment with one of p possible outcomes. " +"An example of such an experiment is throwing a dice, where the outcome can be 1 through 6. " +"Each sample drawn from the distribution represents n such experiments. " +"Its values, X_i = [X_0, X_1, ..., X_p], represent the number of times the outcome was i. +)code") +.set_num_inputs( + [](const nnvm::NodeAttrs& attrs) { + const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed); + return param.pvals.has_value() ?
0U : 1U; + } +) +.set_num_outputs(1) +.set_attr_parser(ParamParser<NumpyMultinomialParam>) +.set_attr<mxnet::FInferShape>("FInferShape", NumpyMultinomialOpShape) +.set_attr<nnvm::FInferType>("FInferType", NumpyMultinomialOpType) +.set_attr<FResourceRequest>("FResourceRequest", + [](const nnvm::NodeAttrs& attrs) { + return std::vector<ResourceRequest>{ + ResourceRequest::kRandom, ResourceRequest::kTempSpace}; + }) +.set_attr<FCompute>("FCompute<cpu>", NumpyMultinomialForward<cpu>) +.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes) +.add_argument("a", "NDArray-or-Symbol", "Source input") +.add_arguments(NumpyMultinomialParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_multinomial_op.cu b/src/operator/numpy/random/np_multinomial_op.cu new file mode 100644 index 0000000..a809260 --- /dev/null +++ b/src/operator/numpy/random/np_multinomial_op.cu @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file np_multinomial_op.cu + * \brief Operator for numpy sampling from multinomial distributions + */ +#include "./np_multinomial_op.h" + +namespace mxnet { +namespace op { + +NNVM_REGISTER_OP(_npi_multinomial) +.set_attr<FCompute>("FCompute<gpu>", NumpyMultinomialForward<gpu>); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/random/np_multinomial_op.h b/src/operator/numpy/random/np_multinomial_op.h new file mode 100644 index 0000000..39515b4 --- /dev/null +++ b/src/operator/numpy/random/np_multinomial_op.h @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * Copyright (c) 2019 by Contributors + * \file np_multinomial_op.h + * \brief Operator for sampling from multinomial distributions + */ +#ifndef MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_ +#define MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_ + +#include <mxnet/operator_util.h> +#include <vector> +#include "../../mshadow_op.h" +#include "../../mxnet_op.h" +#include "../../operator_common.h" +#include "../../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +struct NumpyMultinomialParam : public dmlc::Parameter<NumpyMultinomialParam> { + int n; + dmlc::optional<mxnet::Tuple<double>> pvals; + dmlc::optional<mxnet::Tuple<int>> size; + DMLC_DECLARE_PARAMETER(NumpyMultinomialParam) { + DMLC_DECLARE_FIELD(n) + .describe("Number of experiments."); + DMLC_DECLARE_FIELD(pvals) + .set_default(dmlc::optional<mxnet::Tuple<double>>()) + .describe("Probabilities of each of the p different outcomes. " + "These should sum to 1 (however, the last element is always assumed to " + "account for the remaining probability, as long as sum(pvals[:-1]) <= 1)" + "Note that this is for internal usage only. " + "This operator will only have either input mx.ndarray or this list of pvals"); + DMLC_DECLARE_FIELD(size) + .set_default(dmlc::optional<mxnet::Tuple<int>>()) + .describe("Output shape. If the given shape is, " + "e.g., (m, n, k), then m * n * k samples are drawn. 
" + "Default is None, in which case a single value is returned."); + } +}; + +inline bool NumpyMultinomialOpShape(const nnvm::NodeAttrs& attrs, + std::vector<TShape> *in_attrs, + std::vector<TShape> *out_attrs) { + const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed); + CHECK_EQ(out_attrs->size(), 1U); + + std::vector<dim_t> oshape_vec; + dim_t pvals_length; + if (param.pvals.has_value()) { + CHECK_EQ(in_attrs->size(), 0U); + pvals_length = param.pvals.value().ndim(); + } else { + // pvals is from input ndarray + CHECK_EQ(in_attrs->size(), 1U); + const TShape& ishape = (*in_attrs)[0]; + // check the input shape is only one dimension + CHECK_EQ(ishape.ndim(), 1U) + << "object too deep for desired array"; + pvals_length = ishape[0]; + } + if (param.size.has_value()) { + const mxnet::Tuple<int>& size = param.size.value(); + for (int i = 0; i < size.ndim(); ++i) { + oshape_vec.emplace_back(size[i]); + } + } + oshape_vec.emplace_back(pvals_length); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape(oshape_vec)); + return out_attrs->at(0).ndim() != 0U;; +} + +inline bool NumpyMultinomialOpType(const nnvm::NodeAttrs& attrs, + std::vector<int>* in_attrs, + std::vector<int>* out_attrs) { + const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed); + CHECK_EQ(in_attrs->size(), (param.pvals.has_value()) ? 
0U : 1U); + CHECK_EQ(out_attrs->size(), 1U); + + (*out_attrs)[0] = mshadow::kInt64; + return true; +} + +struct multinomial_kernel { + template<typename DType> + MSHADOW_XINLINE static void Map(int i, + const int num_exp, + const int prob_length, + DType* pvals, + float* uniform, + int64_t* out) { + for (int j = 0; j < num_exp; ++j) { + DType loc = static_cast<DType>(uniform[i * num_exp + j]); + DType acc = 0.0; + bool found = false; + for (int k = 0; k < prob_length; ++k) { + acc += pvals[k]; + if (acc > loc) { + found = true; + out[i * prob_length + k] += 1; + break; + } + } + if (!found) { + out[i * prob_length + (prob_length - 1)] += 1; + } + } + } +}; + +template<typename xpu> +void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector<TBlob>& inputs, + const std::vector<OpReqType>& req, + const std::vector<TBlob>& outputs) { + using namespace mshadow; + using namespace mxnet_op; + const NumpyMultinomialParam& param = nnvm::get<NumpyMultinomialParam>(attrs.parsed); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(inputs.size(), (param.pvals.has_value()) ? 0U : 1U); + + int prob_length = (param.pvals.has_value()) + ? 
param.pvals.value().ndim() : inputs[0].shape_[0]; + // if intput is [] or size contains 0 dimension + if (prob_length == 0U || outputs[0].shape_.Size() == 0) return; + int num_output = outputs[0].Size() / prob_length; + int num_exp = param.n; + Stream<xpu> *s = ctx.get_stream<xpu>(); + Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s); + Tensor<xpu, 1, float> uniform = + ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(num_output * param.n), s); + prnd->SampleUniform(&uniform, 0, 1); + + // set zero for the outputs + Kernel<set_zero, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<int64_t>()); + + if (param.pvals.has_value()) { + // create a tensor to copy the param.pvals tuple to avoid + // error: calling a __host__ function from a __host__ __device__ function is not allowed + Tensor<xpu, 1, double> pvals = + ctx.requested[1].get_space_typed<xpu, 1, double>(Shape1(prob_length), s); + double* pvals_ = pvals.dptr_; + // check if sum of input(pvals) > 1.0 + double sum = 0.0; + for (int i = 0; i < prob_length; ++i) { + sum += param.pvals.value()[i]; + // copy the tuple to data for later kernel usage + pvals_[i] = param.pvals.value()[i]; + CHECK_LE(sum, 1.0) + << "sum(pvals[:-1]) > 1.0"; + } + Kernel<multinomial_kernel, xpu>::Launch( + s, num_output, num_exp, prob_length, pvals_, uniform.dptr_, outputs[0].dptr<int64_t>()); + } else { + MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, { + // check if sum of input(pvals) > 1.0 + DType sum = DType(0); + DType* input = inputs[0].dptr<DType>(); + for (int i = 0; i < prob_length; ++i) { + sum += input[i]; + CHECK_LE(sum, 1.0) + << "sum(pvals[:-1]) > 1.0"; + } + Kernel<multinomial_kernel, xpu>::Launch( + s, num_output, num_exp, prob_length, + inputs[0].dptr<DType>(), uniform.dptr_, outputs[0].dptr<int64_t>()); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_RANDOM_NP_MULTINOMIAL_OP_H_ diff --git a/tests/python/unittest/test_numpy_ndarray.py 
b/tests/python/unittest/test_numpy_ndarray.py index 0d8eacf..e6e4911 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -23,7 +23,7 @@ import numpy as _np import mxnet as mx from mxnet import np, npx, autograd from mxnet.gluon import HybridBlock -from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, assert_exception +from mxnet.test_utils import same, assert_almost_equal, rand_shape_nd, rand_ndarray, retry, assert_exception from common import with_seed, TemporaryDirectory @@ -669,6 +669,51 @@ def test_np_save_load_ndarrays(): assert _np.array_equal(v.asnumpy(), arr_dict[k].asnumpy()) +@retry(5) +@with_seed() +@npx.use_np_shape +def test_np_multinomial(): + pvals_list = [[0.0, 0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1, 0.0]] + sizes = [None, (), (3,), (2, 5, 7), (4, 9)] + experiements = 10000 + for pvals_type in [list, _np.ndarray]: + for have_size in [False, True]: + for pvals in pvals_list: + if have_size: + for size in sizes: + if pvals_type == mx.nd.NDArray: + pvals = mx.nd.array(pvals).as_np_ndarray() + elif pvals_type == _np.ndarray: + pvals = _np.array(pvals) + freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy() / _np.float32(experiements) + # for those cases that didn't need reshape + if size in [None, ()]: + mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1) + else: + # check the shape + assert freq.shape == size + (len(pvals),), 'freq.shape={}, size + (len(pvals))={}'.format(freq.shape, size + (len(pvals))) + freq = freq.reshape((-1, len(pvals))) + # check the value for each row + for i in range(freq.shape[0]): + mx.test_utils.assert_almost_equal(freq[i, :], pvals, rtol=0.20, atol=1e-1) + else: + freq = mx.np.random.multinomial(experiements, pvals).asnumpy() / _np.float32(experiements) + mx.test_utils.assert_almost_equal(freq, pvals, rtol=0.20, atol=1e-1) + # check the zero dimension + sizes = [(0), (0, 2), (4, 0, 2), (3, 0, 1, 
2, 0)] + for pvals in pvals_list: + for size in sizes: + freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy() + assert freq.size == 0 + # check [] as pvals + for pvals in [[], ()]: + freq = mx.np.random.multinomial(experiements, pvals).asnumpy() + assert freq.size == 0 + for size in sizes: + freq = mx.np.random.multinomial(experiements, pvals, size=size).asnumpy() + assert freq.size == 0 + + if __name__ == '__main__': import nose nose.runmodule()